# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn as nn
from modelzoo.common.pytorch.layers.FeedForwardNetwork import FeedForwardNetwork
from modelzoo.vision.pytorch.dit.layers.GaussianDiffusion import index
[docs]class TimestepEmbeddingLayer(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
[docs]    def __init__(
        self,
        num_diffusion_steps,
        hidden_size,
        frequency_embedding_size=256,
        nonlinearity="silu",
        kernel_initializer: str = "xavier_uniform",
        bias_initializer: str = "zeros",
    ):
        super().__init__()
        self.timestep_embedding = self.create_timestep_embedding(
            seq_len=num_diffusion_steps, dim=frequency_embedding_size
        )
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.ffn = FeedForwardNetwork(
            input_unit=frequency_embedding_size,
            layers_units=[hidden_size, hidden_size],
            layers_activation=[nonlinearity, None],
            use_bias=True,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
        )
        # Initialize weights and bias
        self.__reset_parameters() 
    def reset_parameters(self):
        self.__reset_parameters()
    def __reset_parameters(self):
        self.ffn.reset_parameters()
[docs]    @staticmethod
    def create_timestep_embedding(seq_len, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        Slightly different than `EmbeddingLayer.create_fix_pos_embedding`.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        position = torch.arange(seq_len, dtype=torch.float32)
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        )
        args = position[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return torch.nn.Parameter(embedding, requires_grad=False) 
    def forward(self, t):
        t_freq = index(self.timestep_embedding, t)
        t_emb = self.ffn(t_freq)
        return t_emb