Source code for common.pytorch.perf_utils

# Copyright 2022 Cerebras Systems.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import dataclasses
import json
import os
from typing import List

from modelzoo.common.pytorch import cb_model as cm
from modelzoo.common.pytorch import cbtorch

# XLA counter keys used to track various metrics.
# Each key is read through `_get_optional_counter`, which falls back to a
# default value when no counter with that name exists.
# Compile-time counter; `collect_perf_data` scales its value by 1e-3.
KEY_COMPILE_TIME = "CerebrasCompileTimeMs"
# Programming-time counter; `collect_perf_data` scales its value by 1e-9.
KEY_PROGRAMMING_TIME = "CerebrasProgrammingTimeNs"
# Counter stored unscaled as `PerfData.est_samples_per_sec`.
KEY_SYSTEM_PERF = "CerebrasSystemEstSamples"

@dataclasses.dataclass
class PerfData:
    """Data structure for holding performance data.

    Args:
        total_samples: Total number of samples processed.
        total_time: Total time spent processing those samples.
        samples_per_sec: Throughput of processing those samples.
        compile_time: Time spent compiling the model.
        programming_time: Time spent programming the fabric.
        est_samples_per_sec: Estimated throughput based on compile and fabric.
    """

    total_samples: int = 0
    total_time: float = 0.0
    samples_per_sec: float = 0.0
    compile_time: float = 0.0
    programming_time: float = 0.0
    est_samples_per_sec: float = 0.0

    def merge(self, other: "PerfData") -> None:
        """Merge another `PerfData` instance into self.

        Sample counts are summed. All durations (`total_time`,
        `compile_time`, `programming_time`) are combined with `max`, so the
        merged result does not depend on the order in which instances are
        merged. `samples_per_sec` is then recomputed from the merged totals.

        Args:
            other: The other `PerfData` instance to merge.

        Raises:
            AssertionError: If both instances carry a non-zero
                `est_samples_per_sec` and the two values differ.
        """
        self.total_samples += other.total_samples

        # Keep the longest duration rather than whichever non-zero value was
        # merged first; presumably the merged instances describe work done in
        # parallel, so the slowest one bounds the elapsed time.
        self.total_time = max(self.total_time, other.total_time)
        self.compile_time = max(self.compile_time, other.compile_time)
        self.programming_time = max(
            self.programming_time, other.programming_time
        )

        # A value of 0.0 means "no estimate recorded yet" on self; adopt the
        # other instance's estimate in that case, otherwise require agreement.
        if self.est_samples_per_sec == 0.0:
            self.est_samples_per_sec = other.est_samples_per_sec
        else:
            assert (
                self.est_samples_per_sec == other.est_samples_per_sec
            ), "Expected all fabric-based performance estimates to be identical"

        # Guard against division by zero when no time has been recorded.
        if self.total_time > 0:
            self.samples_per_sec = float(self.total_samples) / self.total_time
        else:
            self.samples_per_sec = 0.0

    def throughput_dict(self) -> dict:
        """Return only the throughput-related fields as a plain dict.

        Returns:
            A dict with the keys `total_samples`, `total_time` and
            `samples_per_sec` mapped to the current attribute values.
        """
        return {
            key: getattr(self, key)
            for key in ("total_samples", "total_time", "samples_per_sec")
        }
def _get_optional_counter(name: str, default: int = 0) -> int:
    """Look up an XLA counter, falling back to a default when it is missing.

    Args:
        name: Name of the XLA counter to read.
        default: Value returned when `cm.metrics_counter_value` yields `None`
            for that name. Defaults to 0.

    Returns:
        The counter value, or `default` if the lookup returned `None`.
    """
    value = cm.metrics_counter_value(name)
    return default if value is None else value
def collect_perf_data(tracker: cm.RateTracker):
    """Build a `PerfData` record from XLA counters and a rate tracker.

    Args:
        tracker: Tracker which contains performance data. Its `_partial_count`
            and `_count` attributes and its `global_rate()` method are read.

    Returns:
        A PerfData instance containing the perf data.
    """
    # Counter-derived values: compile time is scaled by 1e-3, programming
    # time by 1e-9, and the estimated throughput is stored as-is.
    compile_time = _get_optional_counter(KEY_COMPILE_TIME) * 1e-3
    programming_time = _get_optional_counter(KEY_PROGRAMMING_TIME) * 1e-9
    estimated_rate = _get_optional_counter(KEY_SYSTEM_PERF)

    # Tracker-derived values.
    sample_count = tracker._partial_count + tracker._count
    measured_rate = tracker.global_rate()

    # Total time is derived from the count and the rate; a non-positive rate
    # yields a zero duration instead of dividing by it.
    elapsed = float(sample_count) / measured_rate if measured_rate > 0 else 0.0

    return PerfData(
        total_samples=sample_count,
        total_time=elapsed,
        samples_per_sec=measured_rate,
        compile_time=compile_time,
        programming_time=programming_time,
        est_samples_per_sec=estimated_rate,
    )
def _aggregate_perf_data(perf_all_ordinals: List[str]):
    """Aggregate performance data from multiple workers.

    Args:
        perf_all_ordinals: List of JSON strings, one per worker, each
            decoding to the keyword arguments of a `PerfData` instance.

    Returns:
        A dict holding the fields of the merged `PerfData`. When more than
        one worker payload is given, it also has an "ordinals" entry listing
        each worker's throughput dict in input order.
    """
    include_ordinals = len(perf_all_ordinals) > 1
    per_worker = []
    combined = PerfData()

    for serialized in perf_all_ordinals:
        current = PerfData(**json.loads(serialized))
        if include_ordinals:
            per_worker.append(current.throughput_dict())
        combined.merge(current)

    summary = {}
    if include_ordinals:
        summary["ordinals"] = per_worker
    summary.update(dataclasses.asdict(combined))
    return summary


def _rendezvous(payload: str):
    """Perform a rendezvous across workers and exchange payload.

    Args:
        payload: String data from each worker to exchange.

    Returns:
        The list of payloads passed by all ordinals, ordered by ordinal
        number. If `cm.rendezvous` returns nothing, a one-element list
        holding this worker's own payload.
    """
    gathered = cm.rendezvous("save_individual_perf_data", payload=payload)
    # no mesh service (i.e., single ordinal)
    return gathered if gathered else [payload]
def save_perf(outdir: str):
    """Write the aggregated performance metrics of a run to disk.

    Does nothing when `cbtorch.state().rate_tracker` is `None`. Otherwise the
    local perf data is exchanged with all ordinals, and the master ordinal
    writes the aggregate to `performance.json` inside `outdir`.

    Args:
        outdir: Output directory to write performance files to. Created if it
            does not exist.
    """
    rate_tracker = cbtorch.state().rate_tracker
    if rate_tracker is None:
        # No performance data to save
        return

    local_perf = collect_perf_data(rate_tracker)
    local_payload = json.dumps(dataclasses.asdict(local_perf))

    # The exchange happens before the master check, so it runs on every
    # ordinal that has a tracker, not only on the master.
    gathered = _rendezvous(local_payload)

    # Only the master ordinal aggregates and writes the file.
    if not cm.is_master_ordinal():
        return

    summary = _aggregate_perf_data(gathered)
    os.makedirs(outdir, exist_ok=True)
    target = os.path.join(outdir, "performance.json")
    with open(target, "w") as fp:
        json.dump(summary, fp, sort_keys=True, indent=4)