# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class NoiseGenerator(torch.nn.Module):
    """Augment an (input, label) batch with randomly sampled diffusion inputs.

    Each call to ``forward`` draws one random timestep per sample and two
    standard-normal noise tensors, and returns them alongside the untouched
    input and label tensors.
    """

    def __init__(self, width, height, channels, num_diffusion_steps, seed=None):
        """
        Args:
            width: Size of the last dimension of the generated noise tensors.
            height: Size of the second-to-last dimension of the noise tensors.
            channels: Size of the channel dimension of the noise tensors.
            num_diffusion_steps: Exclusive upper bound for sampled timesteps;
                timesteps are drawn from ``[0, num_diffusion_steps)``.
            seed: Accepted for interface compatibility; it is not used by
                this class.
        """
        super().__init__()
        self.height = height
        self.width = width
        self.channels = channels
        self.num_diffusion_steps = num_diffusion_steps

    def forward(self, input, label):
        """Sample timesteps and noise for one batch.

        Args:
            input: 4-D tensor whose first dimension is the batch size. Its
                dtype and device are used for the generated noise.
            label: Tensor of labels. Its dtype is used for the timesteps.

        Returns:
            A dict with the following keys:

            "input": ``input``, passed through unchanged.
            "label": ``label``, passed through unchanged.
            "diffusion_noise": Standard-normal tensor of shape
                (batch_size, channels, height, width).
            "timestep": Tensor of shape (batch_size, ) holding one random
                timestep in ``[0, num_diffusion_steps)`` per sample.
            "vae_noise": Standard-normal tensor of shape
                (batch_size, channels, height, width), sampled independently
                of "diffusion_noise".

        Raises:
            ValueError: If ``input`` is not 4-dimensional.
        """
        if input.ndim != 4:
            raise ValueError(f"Samples ndim should be 4. Got {input.ndim}")

        batch_size = input.shape[0]
        device = input.device
        # The noise shape comes from the constructor arguments, not from
        # ``input``; only the batch size is taken from the incoming tensor.
        noise_shape = (batch_size, self.channels, self.height, self.width)

        # Everything is sampled with the default (CPU) generator and then
        # moved, so all generated tensors end up on the same device as input.
        timestep = torch.randint(
            self.num_diffusion_steps, size=(batch_size,), dtype=label.dtype
        ).to(device)
        diffusion_noise = torch.randn(noise_shape, dtype=input.dtype).to(device)
        vae_noise_sample = torch.randn(noise_shape, dtype=input.dtype).to(
            device
        )

        return {
            "input": input,
            "label": label,
            "diffusion_noise": diffusion_noise,
            "timestep": timestep,
            "vae_noise": vae_noise_sample,
        }

    def __repr__(self):
        # Only report attributes that __init__ actually sets.
        return (
            f"{self.__class__.__name__}("
            f"width={self.width}"
            f", height={self.height}"
            f", channels={self.channels}"
            f", num_diffusion_steps={self.num_diffusion_steps}"
            f")"
        )
class LabelDropout(torch.nn.Module):
    """Randomly replace labels with the extra class index ``num_classes``.

    Each sample's label is independently replaced with probability
    ``dropout_prob``; the image is returned unchanged.
    """

    def __init__(self, dropout_prob, num_classes):
        """
        Args:
            dropout_prob: Probability of replacing a sample's label. Must be
                greater than 0.
            num_classes: Value written into dropped labels.
        """
        super().__init__()
        assert dropout_prob > 0, "dropout_prob must be greater than 0"
        self.dropout_prob = dropout_prob
        self.num_classes = num_classes

    def token_drop(self, label):
        """Return a boolean mask of shape (label.shape[0], ) marking drops.

        The mask is sampled with the default (CPU) generator and then moved
        to the device of ``label`` so it can be combined with it directly.
        """
        drop_ids = torch.rand(label.shape[0]) < self.dropout_prob
        return drop_ids.to(label.device)

    def forward(self, image, label):
        """Apply label dropout to one batch.

        Args:
            image: Passed through unchanged.
            label: Tensor of labels whose first dimension is the batch size.

        Returns:
            A tuple ``(image, label)`` where dropped entries of ``label``
            are replaced by ``num_classes``.
        """
        drop_ids = self.token_drop(label)
        # Build the fill value with the dtype and device of the labels so
        # torch.where never has to mix devices.
        dropped_value = torch.tensor(
            self.num_classes, dtype=label.dtype, device=label.device
        )
        label = torch.where(drop_ids, dropped_value, label)
        return image, label