Source code for cerebras.modelzoo.cli.model_info_cli

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Cerebras ModelZoo Model Information CLI Tool"""

import argparse
import json
import pydoc

from cerebras.modelzoo.cli.utils import MZ_CLI_NAME


[docs]class ModelInfoCLI: def __init__(self): parser = argparse.ArgumentParser() self.configure_parser(parser) args = parser.parse_args() args.func(args) @staticmethod def epilog(): return ( f"Use `{MZ_CLI_NAME} model -h` to learn how to query and investigate available models. " f"See below for some basic examples.\n\n" f"List all models:\n" f" $ {MZ_CLI_NAME} model list\n\n" f"Get additional information on gpt2:\n" f" $ {MZ_CLI_NAME} model info gpt2\n\n" f"Get details on all configuration parameters for gpt2:\n" f" $ {MZ_CLI_NAME} model describe gpt2\n\n" f"For more information on ModelZoo models and how they are used, see: " f"https://docs.cerebras.net/en/latest/wsc/Model-zoo/MZ-overview.html" ) @staticmethod def configure_parser(parser): from cerebras.modelzoo.cli.utils import get_table_parser parent_parser = get_table_parser() subparsers = parser.add_subparsers(dest="cmd", required=True) list_parser = subparsers.add_parser( "list", parents=[parent_parser], add_help=False, help="Lists available models.", ) list_parser.set_defaults(func=ModelInfoCLI.model_list) info_parser = subparsers.add_parser( "info", parents=[parent_parser], add_help=False, help="Gives a high level summary of a model and its supported components.", ) info_parser.add_argument( "model", default=None, help="Registered model name to display information on.", ) info_parser.set_defaults(func=ModelInfoCLI.model_info) describe_parser = subparsers.add_parser( "describe", parents=[parent_parser], add_help=False, help="Provides detailed infomation about a given model.", ) describe_parser.add_argument( "model", default=None, help="Registered model name to display information on.", ) describe_parser.set_defaults(func=ModelInfoCLI.model_describe) @staticmethod def model_list(args): from tabulate import tabulate all_models = ModelInfoCLI._list_models() if args.json: print(json.dumps(all_models)) else: table = [[model] for model in all_models] headers = ["Available models in ModelZoo"] table_out = tabulate(table, headers=headers, tablefmt="fancy_grid") if args.no_pager: print(table_out) else: pydoc.pager(table_out) @staticmethod def model_info(args): from tabulate import tabulate model_name = args.model model_path = ModelInfoCLI._get_model_path(model_name) model_configs = ModelInfoCLI._get_model_configs(model_name) model_dataprocessors = ModelInfoCLI._get_model_dataprocessors( model_name ) if args.json: json_dict = { "Name": model_name, "Path": model_path, "Configs": model_configs, "Dataprocessors": model_dataprocessors, } print(json.dumps(json_dict)) else: table = [ ["Name", model_name], # TODO: add description ["Path", model_path], ["Configs", "\n".join(model_configs)], ["Dataprocessors", "\n".join(model_dataprocessors)], ] table_out = tabulate(table, tablefmt="fancy_grid") if args.no_pager: print(table_out) else: pydoc.pager(table_out) @staticmethod def model_describe(args): from tabulate import tabulate from cerebras.modelzoo.config import ( create_config_class, describe_fields, ) from cerebras.modelzoo.registry import registry model_name = args.model model_cls = registry.get_model_class(model_name) model_cfg = ( create_config_class(model_cls).model_fields["config"].annotation ) fields = describe_fields(model_cfg) if args.json: print(json.dumps(fields)) else: header = list(fields[0].keys()) table = [list(field.values()) for field in fields] table_out = tabulate(table, headers=header, tablefmt="fancy_grid") if args.no_pager: print(table_out) else: pydoc.pager(table_out) @staticmethod def _get_model_path(model): from cerebras.modelzoo.registry import registry return str(registry.get_model_path(model)) @staticmethod def _get_model_configs(model): from cerebras.modelzoo.registry import registry model_path = registry.get_model_path(model) config_list = list(model_path.rglob(f"configs/*.yaml")) # handle edge cases if not config_list: config_list = list(model_path.rglob(f"**/*.yaml")) if not config_list: config_list = list(model_path.parent.rglob(f"**/*.yaml")) # can be replaced by str.removeprefix when upgrading to Python 3.9 def remove_prefix(text, prefix): if text.startswith(prefix): return text[len(prefix) :] return text return [remove_prefix(config.stem, "params_") for config in config_list] @staticmethod def _get_model_dataprocessors(model): from cerebras.modelzoo.registry import registry return [ dp.rsplit(".", 1)[-1] for dp in registry.get_model(model).data_processor_paths ] @staticmethod def _list_models(): from cerebras.modelzoo.registry import registry return registry.get_model_names()
if __name__ == '__main__': ModelInfoCLI()