# Source code for cerebras.modelzoo.trainer.callbacks.model_stats

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re
from typing import List, Optional, Tuple

import pandas as pd
from tabulate import tabulate

from cerebras.modelzoo.trainer.callbacks import Callback


class CountParams(Callback):
    r"""
    Callback that runs on model setup for counting the number of parameters
    in a network.

    Along with printing the total number of parameters, it also prints out a
    table which shows the relative contribution (%) that each parameter has to
    the total count.

    Additionally, parameters can be grouped together to better see the
    relative contributions. For example, the following groups parameters
    across layers together using regex style search & replace:

        callbacks:
          - CountParams:
              search_and_replace: [["\\.layers\\.\\d+\\.", ".all_layers."]]

        ╒═══════════════════════════════════════════════════════════════════════════════╤══════════════╤═══════╕
        │ Modules                                                                       │   Parameters │     % │
        ╞═══════════════════════════════════════════════════════════════════════════════╪══════════════╪═══════╡
        │ model.embedding_layer.word_embeddings.weight                                  │    6,432,896 │ 93.96 │
        ├───────────────────────────────────────────────────────────────────────────────┼──────────────┼───────┤
        │ model.embedding_layer.position_embeddings.embed.weight                        │       16,384 │  0.24 │
        ├───────────────────────────────────────────────────────────────────────────────┼──────────────┼───────┤
        │ model.ln_f.weight                                                             │          128 │  0.00 │
        ├───────────────────────────────────────────────────────────────────────────────┼──────────────┼───────┤
        │ model.ln_f.bias                                                               │          128 │  0.00 │
        ├───────────────────────────────────────────────────────────────────────────────┼──────────────┼───────┤
        │ model.transformer_decoder.all_layers.self_attn.proj_q_dense_layer.weight      │       32,768 │  0.48 │
        ├───────────────────────────────────────────────────────────────────────────────┼──────────────┼───────┤
        │ model.transformer_decoder.all_layers.self_attn.proj_q_dense_layer.bias        │          256 │  0.00 │
        ├───────────────────────────────────────────────────────────────────────────────┼──────────────┼───────┤
        │ model.transformer_decoder.all_layers.self_attn.proj_k_dense_layer.weight      │       32,768 │  0.48 │
        ├───────────────────────────────────────────────────────────────────────────────┼──────────────┼───────┤
        │ model.transformer_decoder.all_layers.self_attn.proj_k_dense_layer.bias        │          256 │  0.00 │
        ├───────────────────────────────────────────────────────────────────────────────┼──────────────┼───────┤
        │ model.transformer_decoder.all_layers.self_attn.proj_v_dense_layer.weight      │       32,768 │  0.48 │
        ├───────────────────────────────────────────────────────────────────────────────┼──────────────┼───────┤
        │ model.transformer_decoder.all_layers.self_attn.proj_v_dense_layer.bias        │          256 │  0.00 │
        ├───────────────────────────────────────────────────────────────────────────────┼──────────────┼───────┤
        │ model.transformer_decoder.all_layers.self_attn.proj_output_dense_layer.weight │       32,768 │  0.48 │
        ├───────────────────────────────────────────────────────────────────────────────┼──────────────┼───────┤
        │ model.transformer_decoder.all_layers.self_attn.proj_output_dense_layer.bias   │          256 │  0.00 │
        ├───────────────────────────────────────────────────────────────────────────────┼──────────────┼───────┤
        │ model.transformer_decoder.all_layers.norm1.weight                             │          256 │  0.00 │
        ├───────────────────────────────────────────────────────────────────────────────┼──────────────┼───────┤
        │ model.transformer_decoder.all_layers.norm1.bias                               │          256 │  0.00 │
        ├───────────────────────────────────────────────────────────────────────────────┼──────────────┼───────┤
        │ model.transformer_decoder.all_layers.norm3.weight                             │          256 │  0.00 │
        ├───────────────────────────────────────────────────────────────────────────────┼──────────────┼───────┤
        │ model.transformer_decoder.all_layers.norm3.bias                               │          256 │  0.00 │
        ├───────────────────────────────────────────────────────────────────────────────┼──────────────┼───────┤
        │ model.transformer_decoder.all_layers.ffn.ffn.0.linear_layer.weight            │      131,072 │  1.91 │
        ├───────────────────────────────────────────────────────────────────────────────┼──────────────┼───────┤
        │ model.transformer_decoder.all_layers.ffn.ffn.0.linear_layer.bias              │        1,024 │  0.01 │
        ├───────────────────────────────────────────────────────────────────────────────┼──────────────┼───────┤
        │ model.transformer_decoder.all_layers.ffn.ffn.1.linear_layer.weight            │      131,072 │  1.91 │
        ├───────────────────────────────────────────────────────────────────────────────┼──────────────┼───────┤
        │ model.transformer_decoder.all_layers.ffn.ffn.1.linear_layer.bias              │          256 │  0.00 │
        ╘═══════════════════════════════════════════════════════════════════════════════╧══════════════╧═══════╛
    """

    def __init__(
        self, search_and_replace: Optional[List[Tuple[str, str]]] = None
    ):
        r"""
        Args:
            search_and_replace: An optional list of search & replace to apply
                to parameter names. Each search & replace is a tuple
                containing a regex string for searching and a corresponding
                replacement string. For example, you can "group" parameters
                together across layers by using ``\.layers\.\d+\.`` for search
                and ``.all_layers.`` for replace. The replacement should keep
                the surrounding dots because the search pattern consumes them.
        """
        self._search_and_replace = search_and_replace

    def setup(self, trainer):
        """Build the parameter table for ``trainer.model`` and report it.

        The table's DataFrame is passed to ``trainer.log_metrics`` under the
        ``parameter_counts`` key and the formatted text is emitted through
        ``logging.info``.

        Args:
            trainer: Object exposing ``model`` and ``log_metrics``.
        """
        output, df = self.get_table(trainer.model)
        trainer.log_metrics(parameter_counts=df)
        logging.info(output)

    def get_table(self, model) -> Tuple[str, pd.DataFrame]:
        """Render the per-parameter (or per-group) count table for ``model``.

        As a side effect this stores ``self.total_params`` and
        ``self.total_trainable_params``.

        Only parameters with ``requires_grad`` set contribute rows, but the
        ``%`` column is computed against ``total_params``, which also counts
        frozen parameters, so the column need not sum to 100.

        Args:
            model: Object exposing ``named_parameters()``.

        Returns:
            A tuple ``(text, df)`` where ``text`` is the ``fancy_grid`` table
            followed by the two total lines, and ``df`` is a DataFrame with
            the columns ``Modules``, ``Parameters`` and ``%``.
        """
        (
            self.total_params,
            self.total_trainable_params,
            parameter_counts,
        ) = self.get_parameter_counts(model, self._search_and_replace)

        header = ["Modules", "Parameters", "%"]
        # No division by zero here: if ``total_params`` is 0 there are no
        # parameters at all, so ``parameter_counts`` is empty and the
        # comprehension body never runs.
        # Counts are stored as floats; presumably so that the per-column
        # ``floatfmt`` below (thousands separators) applies to them.
        table = [
            [name, float(count), float(count) / self.total_params * 100]
            for name, count in parameter_counts.items()
        ]
        df = pd.DataFrame(table, columns=header)

        grid = tabulate(
            table,
            header,
            tablefmt="fancy_grid",
            floatfmt=(",.0f", ",.0f", ".2f"),
        )
        out = "\n" + grid + "\n"
        out += f"\nTotal Params (including frozen): {self.total_params:,}"
        out += f"\nTotal Trainable Params: {self.total_trainable_params:,}"

        return out, df

    def get_parameter_counts(
        self, model, search_and_replace=None
    ) -> Tuple[int, int, dict]:
        """Count the parameters of ``model``, optionally grouped by name.

        Args:
            model: Object exposing ``named_parameters()`` that yields
                ``(name, parameter)`` pairs; each parameter must provide
                ``numel()`` and ``requires_grad``.
            search_and_replace: Optional iterable of ``(regex, replacement)``
                pairs applied in order to every trainable parameter name.
                For each pair only the first match is replaced, and the
                replacement is inserted literally (backreferences such as
                ``\\1`` are not expanded). Names that end up identical are
                summed into a single entry.

        Returns:
            A tuple ``(total_params, total_trainable_params, counts)``.
            ``total_params`` includes frozen parameters, while ``counts``
            maps each (possibly rewritten) name to its element count and
            only covers parameters with ``requires_grad`` set.
        """
        parameter_counts = {}
        total_params = 0
        total_trainable_params = 0

        for name, parameter in model.named_parameters():
            count = parameter.numel()
            total_params += count

            # Frozen parameters count towards the grand total only.
            if not parameter.requires_grad:
                continue
            total_trainable_params += count

            group_name = name
            for search_regex, replace_str in search_and_replace or ():
                # A callable replacement keeps ``replace_str`` literal, and
                # ``count=1`` limits the substitution to the first match.
                group_name = re.sub(
                    search_regex,
                    lambda _match, _repl=replace_str: _repl,
                    group_name,
                    count=1,
                )

            parameter_counts[group_name] = (
                parameter_counts.get(group_name, 0) + count
            )

        return total_params, total_trainable_params, parameter_counts