# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import getpass
import logging
import os
import re
import shlex
import subprocess
import time
from pathlib import Path

import paramiko
logging.basicConfig(
    format='%(asctime)s %(name)s: %(levelname)-8s %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S',
)
# Checkpoint file names recognized by this script. They are used with
# re.fullmatch, and the dots are escaped so that only a literal "." matches
# (an unescaped "." would also accept names such as "checkpoint_10xmdl").
PT_CKPT_PATTERN = r"checkpoint_\d+\.mdl"
TF_CKPT_PATTERN = r"model\.ckpt-\d+"
# matches a checkpoint name of either kind
CKPT_PATTERN = f"({PT_CKPT_PATTERN})|({TF_CKPT_PATTERN})"
# don't copy a checkpoint unless it's been untouched for 2 minutes
# (value is in seconds)
CKPT_UNTOUCHED_THRESHOLD = 2 * 60
def parse_args(argv=None):
    """Parse the command-line options of the checkpoint copy script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: If a required option is missing or a value is invalid
            (standard argparse behavior).
    """
    parser = argparse.ArgumentParser()
    # The three locations and the coarse interval are used unconditionally
    # by the rest of the script, so a missing value is rejected here rather
    # than surfacing later as a TypeError on None.
    parser.add_argument(
        "--model_dir_colo",
        required=True,
        help="Model directory on the remote host that is polled for checkpoints",
    )
    parser.add_argument(
        "--remote_host",
        required=True,
        help="Host that holds `model_dir_colo`",
    )
    parser.add_argument(
        "--model_dir_aws",
        required=True,
        help="Local directory that receives params, checkpoints and logs",
    )
    parser.add_argument(
        "--coarse_checkpoint_steps",
        type=int,
        required=True,
        help=(
            "The frequency with which checkpoints are saved with the "
            "intention of long term storage and analysis. Often this interval "
            "is coarser than the frequency with which checkpoints are saved "
            "for restart purposes, see '--keep_last_n_checkpoints'"
        ),
    )
    parser.add_argument(
        "--keep_last_n_checkpoints",
        type=int,
        help=(
            "How many checkpoints to keep on remote for restarts in addition "
            "to those kept according to `coarse_checkpoint_steps` for long "
            "term storage and analysis"
        ),
    )
    parser.add_argument(
        "--polling_interval",
        type=int,
        default=60 * 5,
        help="How often to check for new events (in seconds)",
    )
    parser.add_argument(
        "--analyze_weights",
        action="store_true",
        help="Extract summaries from weights after copying to aws",
    )
    return parser.parse_args(argv)
def exists_remote(remote_host, p):
    """Check whether `p` is a regular file or a directory on `remote_host`.

    Args:
        remote_host: Host passed to the ``ssh`` command line.
        p: Path to test on the remote host.

    Returns:
        True if the remote ``test -f`` or ``test -d`` succeeds, False
        otherwise (including when the ssh command itself fails).
    """
    # The command string is interpreted by the remote shell, so the path is
    # quoted to survive spaces and shell metacharacters.
    quoted = shlex.quote(str(p))
    # A single ssh round-trip covers both the file and the directory test.
    remote_cmd = f"test -f {quoted} || test -d {quoted}"
    return subprocess.call(["ssh", remote_host, remote_cmd]) == 0
def ckpt_name_to_step_num(name):
    """Extract the step number from a checkpoint file name.

    Args:
        name: Checkpoint name fully matching either ``PT_CKPT_PATTERN``
            (``checkpoint_<step>.mdl``) or ``TF_CKPT_PATTERN``
            (``model.ckpt-<step>``).

    Returns:
        The step number as an int.

    Raises:
        ValueError: If `name` matches neither checkpoint pattern.
    """
    if re.fullmatch(PT_CKPT_PATTERN, name):
        # strip the fixed "checkpoint_" prefix and ".mdl" suffix
        return int(name[len("checkpoint_") : -len(".mdl")])
    if re.fullmatch(TF_CKPT_PATTERN, name):
        # everything after the fixed "model.ckpt-" prefix is the step
        return int(name[len("model.ckpt-") :])
    raise ValueError(
        f"attempted to extract step number from invalid checkpoint {name}"
    )
def _remote_int(remote_host, remote_cmd):
    """Run `remote_cmd` on `remote_host` via ssh and parse stdout as an int.

    Returns None (after logging a warning) when ssh exits with a non-zero
    status or its output is not an integer.
    """
    result = subprocess.run(
        ["ssh", remote_host, *remote_cmd], capture_output=True, text=True,
    )
    if result.returncode == 0:
        try:
            return int(result.stdout)
        except ValueError:
            pass
    logging.warning(
        f"Remote command {remote_cmd} on {remote_host} failed with exit "
        f"code {result.returncode}: {result.stderr.strip()}"
    )
    return None


def _submit_job(cmd):
    """Run a job submission command and return the last token of its stdout.

    The last whitespace-separated token is taken to be the slurm job id.
    Returns None (after logging an error) when the command exits with a
    non-zero status or prints nothing.
    """
    result = subprocess.run(cmd, capture_output=True, text=True)
    tokens = result.stdout.split()
    if result.returncode != 0 or not tokens:
        logging.error(
            f"Job submission {cmd} failed with exit code "
            f"{result.returncode}: {result.stderr.strip()}"
        )
        return None
    return tokens[-1]


def maybe_copy_checkpoint(ckpt, args):
    """Launch the copy of a remote checkpoint once it has stopped changing.

    The checkpoint is only handled when its modification time on the remote
    host is more than ``CKPT_UNTOUCHED_THRESHOLD`` seconds in the past. In
    that case a ``launch_checkpoint_copy.sh`` job is submitted through
    ``cbrun -- sbatch`` and, if ``args.analyze_weights`` is set, a
    ``write_weight_summaries.py`` job is queued to run after the copy job
    finishes successfully.

    Args:
        ckpt: Checkpoint file name (not a path) inside
            ``args.model_dir_colo``.
        args: Parsed command-line arguments; reads ``model_dir_colo``,
            ``model_dir_aws``, ``remote_host`` and ``analyze_weights``.

    Returns:
        True if the copy job was submitted. False if the checkpoint was
        modified too recently, or if querying the remote host or submitting
        the copy job failed, so the caller can retry later.

    Raises:
        ValueError: If `ckpt` is not a valid checkpoint name.
    """
    ckpt_path = os.path.join(args.model_dir_colo, ckpt)
    logs_dir = os.path.join(args.model_dir_aws, "logs")
    step_num = ckpt_name_to_step_num(ckpt)

    # ssh hands the command to the remote shell, so the path is quoted
    modified_time = _remote_int(
        args.remote_host, ["stat", shlex.quote(ckpt_path), "-c", r"%Y"]
    )
    if modified_time is None:
        return False
    # get time from remote machine to remove potential consistency
    # or time zone issues
    current_time = _remote_int(args.remote_host, ["date", r"+%s"])
    if current_time is None:
        return False

    # wait a few minutes before copying checkpoints to avoid
    # copying partially written files
    if current_time - modified_time <= CKPT_UNTOUCHED_THRESHOLD:
        return False

    log_file_path = os.path.join(
        logs_dir, f"logs_process_checkpoint_{step_num}.out"
    )
    copy_cmd = [
        "cbrun",
        "--",
        "sbatch",
        "-c4",
        "-o",
        log_file_path,
        "launch_checkpoint_copy.sh",
        args.model_dir_colo,
        args.model_dir_aws,
        ckpt,
        args.remote_host,
    ]
    slurm_id = _submit_job(copy_cmd)
    if slurm_id is None:
        # nothing was launched; report failure so the caller retries
        return False
    logging.info(
        f"Launched copy and processing of checkpoint {ckpt} with "
        f"slurm id {slurm_id}."
    )

    if args.analyze_weights:
        # Queue up weight analysis to run after checkpoint copy
        framework = "pt" if re.fullmatch(PT_CKPT_PATTERN, ckpt) else "tf"
        aws_ckpt_path = os.path.join(args.model_dir_aws, ckpt)
        summary_cmd = [
            "cbrun",
            "--",
            "sbatch",
            "-c8",
            # append to the log file already used by the copy job
            "--open-mode=append",
            "-o",
            log_file_path,
            "-d",
            f"afterok:{slurm_id}",
            "write_weight_summaries.py",
            "--input_path",
            aws_ckpt_path,
            "--output_path",
            aws_ckpt_path + ".wt_summary.txt",
            "--framework",
            framework,
        ]
        summaries_slurm_id = _submit_job(summary_cmd)
        if summaries_slurm_id is not None:
            logging.info(
                f"Queued weight analysis of checkpoint {ckpt} with "
                f"slurm id {summaries_slurm_id} to start after job "
                f"{slurm_id} finishes successfully."
            )
    # the copy job itself was submitted, even if the analysis was not
    return True
def _process_remote_checkpoints(sftp, args, processed_checkpoints):
    """Handle every checkpoint currently in the remote model directory.

    Checkpoints are visited from the highest step number down. A checkpoint
    whose step is a multiple of ``args.coarse_checkpoint_steps`` is passed
    to ``maybe_copy_checkpoint`` and, on success, recorded in
    `processed_checkpoints`. Any other unprocessed checkpoint is removed
    from the remote host once it falls outside the newest
    ``args.keep_last_n_checkpoints`` entries (when that option is set).
    """
    all_ckpts = sorted(
        (
            f
            for f in sftp.listdir(args.model_dir_colo)
            if re.fullmatch(CKPT_PATTERN, f)
        ),
        key=ckpt_name_to_step_num,
    )
    for i, ckpt in enumerate(reversed(all_ckpts)):
        step_num = ckpt_name_to_step_num(ckpt)
        if step_num in processed_checkpoints:
            continue
        if step_num % args.coarse_checkpoint_steps == 0:
            if maybe_copy_checkpoint(ckpt, args):
                processed_checkpoints.add(step_num)
        elif (
            args.keep_last_n_checkpoints is not None
            and i >= args.keep_last_n_checkpoints
        ):
            ckpt_path = os.path.join(args.model_dir_colo, ckpt)
            logging.info(f"Removing remote checkpoint {ckpt_path}")
            sftp.remove(ckpt_path)


def main():
    """Poll the remote model directory and copy its params and checkpoints.

    Creates the local model and log directories, connects to the remote
    host over SSH as user ``lab`` (prompting for the password) and then
    loops forever: the params file is copied once it exists, and each
    polling round processes the remote checkpoints. The SSH and SFTP
    connections are closed when the loop is left by an exception or an
    interrupt.
    """
    args = parse_args()
    # exist_ok avoids the race between an existence check and the creation;
    # parents=True also creates model_dir_aws itself when it is missing
    logs_dir = Path(args.model_dir_aws) / "logs"
    logs_dir.mkdir(parents=True, exist_ok=True)

    params_copied = False
    params_path = os.path.join(
        args.model_dir_colo, "train", "params_train.yaml"
    )
    # checkpoints already present locally are not copied again
    processed_checkpoints = {
        ckpt_name_to_step_num(f)
        for f in os.listdir(args.model_dir_aws)
        if re.fullmatch(CKPT_PATTERN, f)
    }

    ssh = paramiko.SSHClient()
    ssh.set_missing_host_key_policy(paramiko.client.AutoAddPolicy())
    try:
        ssh.connect(
            args.remote_host,
            username="lab",
            password=getpass.getpass(f"password for lab@{args.remote_host}: "),
        )
        sftp = ssh.open_sftp()
        try:
            while True:
                tick = time.time()
                # copy params files
                if not params_copied and exists_remote(
                    args.remote_host, params_path
                ):
                    logging.info(
                        f"Copying params {args.remote_host}:{params_path}"
                    )
                    sftp.get(
                        params_path,
                        os.path.join(args.model_dir_aws, "params_train.yaml"),
                    )
                    params_copied = True
                # copy checkpoints
                _process_remote_checkpoints(sftp, args, processed_checkpoints)
                # sleep for whatever is left of the polling interval
                elapsed = time.time() - tick
                time.sleep(max(args.polling_interval - elapsed, 0))
        finally:
            sftp.close()
    finally:
        ssh.close()
if __name__ == "__main__":
    main()