# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Provides DumpContext, a debug utility for dumping activations and gradients on
a CPU/GPU run, and setting up debug names for dumped WSE activations to be
automatically correlated.
"""
import functools
import os
import warnings
from collections import defaultdict
from contextlib import ContextDecorator
from typing import Optional
import numpy as np
import torch
import cerebras_pytorch as cstorch
from cerebras_pytorch.utils.nest import visit_torch_tensors


class DumpContext(ContextDecorator):
    """
    A debug utility context manager. When provided with a torch.nn.Module, the
    resulting context manager can be entered to enable dumping of all module
    forward and backward outputs to an .npz file, for comparing numerics
    between implementations.
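
    For example (a minimal sketch; the model, inputs, and loss below are
    hypothetical)::

        model = MyModel()
        dump_ctx = DumpContext(outdir="dumps", model=model)
        with dump_ctx:
            loss = model(inputs).sum()
            loss.backward()
        dump_ctx.flush()  # writes dumps/act_dumps.npz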
    """

    def __init__(
        self,
        outdir: str,
        model: torch.nn.Module,
        buffer_steps: Optional[int] = None,
    ):
        """
        Sets up global module hooks to either dump intermediate activations on
        CPU/GPU or name the traced tensors for correlating with debug dumps on
        CS2.

        The recursive debug name of each submodule is memoized, and the
        forward (FWD) and backward (BWD) outputs of each module are saved
        under corresponding keys in a .npz file.

        Args:
            outdir: Where to write act_dumps.npz (and act_dumps_{i}.npz for
                subsequent flushes)
            model: Root module whose children will be recursively named
            buffer_steps: If given, flush to a new .npz file after this many
                steps
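
        For example, a submodule whose debug name is "fc1" (hypothetical)
        would contribute keys such as "fc1.fwd" and "fc1.bwd" to the
        resulting .npz file, plus "fc1.bwd.weight" for its parameter
        gradients.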
        """
        self._outdir = outdir
        os.makedirs(self._outdir, exist_ok=True)
        # The actual hook functions to install
        self._forward_pre_hook = None
        self._forward_hook = None
        self._backward_hook = None
        self._full_backward_hook = None
        self.setup_hooks(model)
        # Any installed hooks, set during enable_collection()
        self._module_hooks = []
        self._call_counter = {}
        self._buffer_steps = buffer_steps
        self._flush_count = 0
        self._buffer = defaultdict(list)

    def __enter__(self):
        self.enable_collection()
        return self

    def __del__(self):
        self.flush()

    def __exit__(self, *exc):
        self.disable_collection()
        # Check if we need to flush by using the first buffer's size as a
        # proxy for how many steps we've captured.
        if self._buffer_steps and self._buffer:
            first_buffer = next(iter(self._buffer.values()))
            if len(first_buffer) >= self._buffer_steps:
                self.flush()

    def setup_hooks(self, model):
        """
        Define hooking functions on the given torch.nn.Module, but don't
        install them.
        Args:
            model: torch.nn.Module that serves as the root for recursive names
        """
        if cstorch.use_cs():
            # Not enabled for CSX, dumping only works on CPU/GPU
            return
        cstorch.add_debug_name(model)
        # Helpers for hooks
        def get_name(module, counter_increment=0):
            name = cstorch.get_debug_name(module)
            def_counter = 0 if counter_increment >= 0 else 1
            counter = self._call_counter.setdefault(name, def_counter)
            self._call_counter[name] += counter_increment
            if counter != def_counter:
                name = f"{name}.call{counter}"
            return name
        def save_tensors(top_scope, tensors):
            for scope, tensor in visit_torch_tensors(tensors, scope=top_scope):
                tensor = tensor.detach().to("cpu").clone()
                if tensor.dtype == torch.bfloat16:
                    warnings.warn(
                        "Encountered bfloat16 tensor in summary collection. "
                        "Numpy does not natively support bfloat16, so any "
                        "torch.bfloat16 tensors will be saved as np.float32."
                    )
                    tensor = tensor.float()
                name = ".".join(scope)
                self._buffer[name].append(tensor.numpy())
        # pylint: disable=redefined-builtin
        def save_output(key, module, input, output):
            """
            Saves to numpy arrays in the output directory.
            """
            counter_increment = 1
            if key == "bwd":
                counter_increment = -1
                # hook args are `grad_input, grad_output`, where grad_input
                # is the _gradient_ of the module's input i.e. the output
                # of the backward pass and the more interesting value to
                # dump. This way, the dump named `module.fwd` is the output
                # of the forward pass (i.e. txact), and `module.bwd` is the
                # output of the backward pass (i.e. txdelta) for the
                # corresponding kernel
                output = input
            name = get_name(module, counter_increment)
            save_tensors([name, key], output)
        self._forward_hook = functools.partial(save_output, "fwd")
        self._full_backward_hook = functools.partial(save_output, "bwd")
        # Add hook capturing parameter gradients
        def param_grad_hook(module, input):
            module_name = get_name(module)
            for name, param in module.named_parameters(recurse=False):
                if param.requires_grad and not hasattr(param, "dump_context"):
                    param.dump_context = True
                    scope = [module_name, "bwd", name]
                    param.register_hook(functools.partial(save_tensors, scope))
        self._forward_pre_hook = param_grad_hook

    def enable_collection(self):
        """
        Install the hooks defined during `setup_hooks`, enabling the
        collection of the dumps.
        """
        def install_if_set(hook):
            hook_fn = getattr(self, f"_{hook}_hook")
            if hook_fn:
                register_fn = f"register_module_{hook}_hook"
                return getattr(torch.nn.modules.module, register_fn)(hook_fn)
            return None
        hooks = ("forward_pre", "forward", "backward", "full_backward")
        self._module_hooks = [install_if_set(hook) for hook in hooks]
        # Clear call counters
        self._call_counter = {}

    def disable_collection(self):
        """
        Uninstall the hooks installed during `enable_collection`, disabling
        further dump collection.
        """
        for hook in self._module_hooks:
            if hook:
                hook.remove()
        self._module_hooks = []

    def flush(self):
        """
        Write all dump buffers out to disk.
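
        Each key in the written .npz file maps to its dumped values stacked
        across the captured steps; the file can be inspected with, e.g.,
        numpy.load.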
        """
        if not self._buffer:
            return
        if self._flush_count:
            outfile = f"act_dumps_{self._flush_count}.npz"
        else:
            outfile = "act_dumps.npz"
        np.savez(
            os.path.join(self._outdir, outfile),
            **{key: np.stack(values) for key, values in self._buffer.items()},
        )
        self._buffer.clear()
        self._flush_count += 1