# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# isort: off
import sys
import os
# Append the directory four levels above this file's folder to sys.path so the
# `modelzoo.*` imports below can resolve when this file is launched directly
# as a script (presumably that directory is the repository root -- confirm
# against the checkout layout). Must run before those imports, hence the
# isort guards around this block.
sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../"))
# isort: on
"""
Usage: 
torchrun --nnodes 1 --nproc_per_node 4 \
modelzoo/vision/pytorch/dit/sample_generator_dit.py \
--model_ckpt_path <trained_diffusion_ckpt path> \
--vae_ckpt_path <vae_ckpt_path> \
--params modelzoo/vision/pytorch/dit/configs/params_dit_xlarge_patchsize_2x2.yaml \
--sample_dir=modelzoo/vision/pytorch/dit/sample_dir \
--num_fid_samples=50000
"""
import argparse
import logging
import cerebras_pytorch as cstorch
from modelzoo.vision.pytorch.dit.layers.vae.VAEModel import AutoencoderKL
from modelzoo.vision.pytorch.dit.model import DiTModel
from modelzoo.vision.pytorch.dit.sample_generator import SampleGenerator
class DiTSampleGenerator(SampleGenerator):
    """
    Sample generator specialization for DiT models.

    Supplies the DiT diffusion model and the VAE model through
    `create_diffusion_model` and `create_vae_model`; all remaining
    behavior is inherited from `SampleGenerator`.
    """

    def __init__(
        self,
        model_ckpt_path,
        vae_ckpt_path,
        params,
        sample_dir,
        seed,
        num_fid_samples=50000,
        per_gpu_batch_size=None,
    ):
        """
        Class for DiT model sample generation

        Args:
            model_ckpt_path (str): Path to pretrained diffusion model checkpoint
            vae_ckpt_path (str): Path to pretrained VAE model checkpoint
            params (str): Path to yaml containing model params
            sample_dir (str): Path to folder where generated images
                and npz file to be stored
            seed (int): Seed for random generation process
            num_fid_samples (int): Number of images to be generated
            per_gpu_batch_size (int): Per gpu batch size,
                command line input overrides that in yaml if provided.
        """
        # All arguments are forwarded unchanged to the base class.
        super().__init__(
            model_ckpt_path,
            vae_ckpt_path,
            params,
            sample_dir,
            seed,
            num_fid_samples,
            per_gpu_batch_size,
        )

    def create_diffusion_model(self, params, model_ckpt_path, use_cfg, device):
        """
        Initialize DiT model and load ckpt if provided

        Args:
            params (dict): params to be passed to DiT model initialization
            model_ckpt_path (str): Path to model checkpoint without VAE.
                If falsy, the model keeps its initial (random) weights.
            use_cfg (bool): If True, apply classifier free guidance on inputs
                i.e select the appropriate forward fcn based on `use_cfg`
            device (str): Target device for model

        Returns:
            dit_model (nn.Module): DiT model
            fwd_fn (Callable) : Forward fcn to be used by
                pipeline object for sampling
        """
        # Initialize DiT Model
        dit_model = DiTModel(params)

        # Load checkpoint; weights are stored under the "model" key.
        if model_ckpt_path:
            _dit_dict = cstorch.load(model_ckpt_path)
            dit_model.load_state_dict(_dit_dict["model"])
            logging.info(f"Initializing DiT model with {model_ckpt_path}")
        else:
            logging.info("Initializing DiT model with random weights")

        # The state dict is loaded into the DiTModel wrapper, but sampling
        # uses its inner `.model`, which provides the forward functions below.
        dit_model = dit_model.model
        dit_model.to(device)

        # Select forward fcn
        if use_cfg:
            fwd_fn = dit_model.forward_dit_with_cfg
        else:
            fwd_fn = dit_model.forward_dit

        return dit_model, fwd_fn

    def create_vae_model(self, vae_params, vae_ckpt_path, device):
        """
        Initialize VAE model and load ckpt if provided

        Args:
            vae_params (dict): params to initialize VAE model
            vae_ckpt_path (str) : Path to VAE model checkpoint.
                If falsy, the model keeps its initial (random) weights.
            device (str): Target device for model

        Returns:
            vae_model (nn.Module): VAE model for decoding
        """
        # Initialize VAE model
        vae_model = AutoencoderKL(**vae_params)

        # Load checkpoint; unlike the DiT checkpoint, the loaded object is
        # passed to `load_state_dict` directly (no "model" key lookup).
        if vae_ckpt_path:
            _vae_dict = cstorch.load(vae_ckpt_path)
            vae_model.load_state_dict(_vae_dict)
            logging.info(f"Initializing VAE model with {vae_ckpt_path}")
        else:
            logging.info("Initializing VAE model with random weights")

        vae_model.to(device)
        return vae_model
def get_parser_args(argv=None):
    """
    Parse the command line options of the DiT sample generation script.

    Args:
        argv (list[str], optional): Argument strings to parse. When None
            (the default), argparse reads them from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Parsed options with attributes ``seed``,
            ``model_ckpt_path``, ``vae_ckpt_path``, ``params``,
            ``num_fid_samples``, ``sample_dir``, ``batch_size`` and
            ``create_grid``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Seed for random generation process",
    )
    parser.add_argument(
        "--model_ckpt_path",
        type=str,
        default=None,
        help="Optional path to a diffusion model checkpoint",
    )
    parser.add_argument(
        "--vae_ckpt_path",
        type=str,
        default=None,
        help="Optional VAE model checkpoint path",
    )
    parser.add_argument(
        "--params",
        type=str,
        required=True,
        help="Path to params to initialize Diffusion model and VAE models",
    )
    parser.add_argument(
        "--num_fid_samples",
        type=int,
        default=50000,
        help="number of samples to generate",
    )
    parser.add_argument(
        "--sample_dir",
        type=str,
        required=True,
        help="Directory to store generated samples",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        required=False,
        default=None,
        help="per-gpu batch size for forward pass",
    )
    parser.add_argument(
        "--create_grid",
        action="store_true",
        required=False,
        help="If passed, create a grid from images generated",
    )
    # `argv=None` makes argparse fall back to sys.argv[1:], which keeps the
    # original zero-argument call working unchanged.
    args = parser.parse_args(argv)
    return args
if __name__ == "__main__":
    # Read the CLI options and hand them to the generator by keyword so the
    # mapping between flags and constructor parameters is explicit.
    args = get_parser_args()
    sample_gen = DiTSampleGenerator(
        model_ckpt_path=args.model_ckpt_path,
        vae_ckpt_path=args.vae_ckpt_path,
        params=args.params,
        sample_dir=args.sample_dir,
        seed=args.seed,
        num_fid_samples=args.num_fid_samples,
        per_gpu_batch_size=args.batch_size,
    )
    sample_gen.run()

    # Optional post-processing step, only when --create_grid was passed.
    if args.create_grid:
        # Imported lazily: these are needed only for the grid step.
        import math
        from modelzoo.vision.pytorch.dit.display_images import display_images

        logging.info("Creating grid from samples generated....")
        # nrow is the ceiling of the square root of the requested sample count.
        side = math.sqrt(args.num_fid_samples)
        nrow = math.ceil(side)
        display_images(folder_path=args.sample_dir, nrow=nrow)