# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
# Module logger, keyed by this module's import name.
logger = logging.getLogger(name=__name__)  # pylint: disable=invalid-name
def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[
        Union[List["torch.Generator"], "torch.Generator"]
    ] = None,
    device: Optional[Union[str, "torch.device"]] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """
    Create a standard-normal random tensor on ``device`` with ``dtype``.

    When a list of generators is passed, each entry along dim 0 of ``shape``
    is drawn from its own generator, so every batch element can be seeded
    individually. A list holding a single generator is treated like that
    generator on its own. If CPU generators are passed while a non-CPU
    ``device`` is requested, the values are drawn on CPU and then moved to
    ``device``.

    Args:
        shape: Output shape (tuple or list). When ``generator`` is a list,
            ``shape[0]`` is the batch size and the list must contain at least
            that many generators (extra generators are ignored).
        generator: A ``torch.Generator``, a list of them, or ``None`` to use
            the global RNG.
        device: Target device, as a ``torch.device`` or a string such as
            ``"cpu"``. If ``None``, the result is moved to CPU.
        dtype: Output dtype; ``None`` leaves the choice to ``torch.randn``.
        layout: Output layout; defaults to ``torch.strided``.

    Returns:
        A tensor of shape ``shape`` located on ``device``.

    Raises:
        ValueError: If a CUDA generator is used with a device of another type.
    """
    # Accept device strings such as "cuda" as well as torch.device objects.
    if isinstance(device, str):
        device = torch.device(device)
    # Values are drawn on the requested device unless redirected to CPU below;
    # a None here lets torch.randn use its default device.
    rand_device = device
    layout = layout or torch.strided
    device = device or torch.device("cpu")

    # A one-element list seeds the whole batch, exactly like a bare generator.
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if generator is not None:
        gen_device_type = (
            generator[0].device.type
            if isinstance(generator, list)
            else generator.device.type
        )
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            # Compare the device *type*: a torch.device never equals a str, so
            # comparing against "mps" directly would always be unequal.
            # NOTE(review): presumably the hint is skipped for MPS because a
            # CPU generator is the expected setup there — confirm.
            if device.type != "mps":
                logger.info(
                    "The passed generator was created on 'cpu' even though a tensor on %s was expected."
                    " Tensors will be created on 'cpu' and then moved to %s. Note that one can probably"
                    " slightly speed up this function by passing a generator that was created on the %s device.",
                    device,
                    device,
                    device,
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(
                f"Cannot generate a {device} tensor from a generator of type {gen_device_type}."
            )

    if isinstance(generator, list):
        batch_size = shape[0]
        # tuple() so that list shapes can be concatenated with (1,).
        sample_shape = (1,) + tuple(shape[1:])
        # One draw per batch element, each from its own generator.
        latents = torch.cat(
            [
                torch.randn(
                    sample_shape,
                    generator=generator[i],
                    device=rand_device,
                    dtype=dtype,
                    layout=layout,
                )
                for i in range(batch_size)
            ],
            dim=0,
        )
    else:
        latents = torch.randn(
            shape,
            generator=generator,
            device=rand_device,
            dtype=dtype,
            layout=layout,
        )
    return latents.to(device)
def _emulate_chunk2_dim1(x):
    c = x.shape[1] // 2
    return x[:, 0:c, ...], x[:, c:, ...]
class DiagonalGaussianDistribution:
    """
    Diagonal Gaussian whose mean and log-variance are packed in one tensor.

    ``parameters`` is split into two halves along dim 1: the first half is the
    mean, the second half the log-variance. The log-variance is clamped to
    ``[-30.0, 20.0]`` and ``std``/``var`` are derived from it. With
    ``deterministic=True`` both ``std`` and ``var`` are zero, so ``sample``
    returns the mean and ``kl``/``nll`` return zero.
    """

    def __init__(self, parameters, deterministic=False):
        """
        Args:
            parameters: Tensor holding mean and log-variance concatenated
                along dim 1 (assumes an even size on dim 1 — TODO confirm
                with callers).
            deterministic: If True, treat the distribution as a point mass at
                the mean (zero std and variance).
        """
        self.parameters = parameters
        self.mean, self.logvar = _emulate_chunk2_dim1(parameters)
        # Bound the log-variance before exponentiating (presumably to avoid
        # exp() overflow/underflow).
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # Zero spread: sample() then returns mean + 0 * noise.
            self.var = self.std = torch.zeros_like(
                self.mean,
                device=self.parameters.device,
                dtype=self.parameters.dtype,
            )

    def sample(
        self, noise=None, generator: Optional[torch.Generator] = None
    ) -> torch.FloatTensor:
        """
        Draw a reparameterized sample ``mean + std * noise``.

        Args:
            noise: Optional pre-drawn standard-normal noise shaped like
                ``mean``. If None, it is generated with ``randn_tensor`` on the
                parameters' device and dtype.
            generator: Optional RNG forwarded to ``randn_tensor``.
        """
        if noise is None:
            noise = randn_tensor(
                self.mean.shape,
                generator=generator,
                device=self.parameters.device,
                dtype=self.parameters.dtype,
            )
        return self.mean + self.std * noise

    def kl(self, other=None, dims=(1, 2, 3)):
        """
        KL divergence ``KL(self || other)`` summed over ``dims``.

        Args:
            other: Another DiagonalGaussianDistribution. If None, the
                divergence is taken against the standard normal N(0, I).
            dims: Dimensions to reduce. The default (1, 2, 3) sums every
                dimension except dim 0 of a 4-D tensor.

        Returns:
            The reduced KL tensor, or a shape-(1,) zero tensor on the
            parameters' device and dtype when deterministic.
        """
        if self.deterministic:
            # Same device/dtype as the parameters so it combines with losses.
            return torch.zeros(
                1, device=self.parameters.device, dtype=self.parameters.dtype
            )
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=dims,
            )
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=dims,
        )

    def nll(self, sample, dims=(1, 2, 3)):
        """
        Negative log-likelihood of ``sample`` under this Gaussian.

        Args:
            sample: Tensor broadcastable against ``mean``.
            dims: Dimensions to reduce (default (1, 2, 3)).

        Returns:
            The reduced NLL tensor, or a shape-(1,) zero tensor on the
            parameters' device and dtype when deterministic.
        """
        if self.deterministic:
            return torch.zeros(
                1, device=self.parameters.device, dtype=self.parameters.dtype
            )
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi
            + self.logvar
            + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        """Return the mode, which for a Gaussian is its mean."""
        return self.mean