# Source code for cerebras.modelzoo.data.vision.segmentation.preprocessing_utils
# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torchvision import transforms
def normalize_tensor_transform(img, normalize_data_method):
    """
    Normalize the values of `img` using the requested scheme.

    :params img: Input torch.Tensor of any shape
    :params normalize_data_method: One of
        None             -> `img` is returned unchanged
        "zero_centered"  -> img / 127.5 - 1
        "zero_one"       -> img / 255.0
        "standard_score" -> (img - img.mean()) / img.std(), with the
            statistics computed over all elements of `img`
    :returns torch.Tensor with the same shape as `img`
    :raises ValueError: if `normalize_data_method` is none of the above
    """
    if normalize_data_method is None:
        return img
    if normalize_data_method == "zero_centered":
        return torch.div(img, 127.5) - 1
    if normalize_data_method == "zero_one":
        return torch.div(img, 255.0)
    if normalize_data_method == "standard_score":
        # `mean`/`std` are only defined for floating point (or complex)
        # tensors; cast integer/bool images first so this mode accepts
        # the same inputs as the other two modes.
        if not (img.is_floating_point() or img.is_complex()):
            img = img.to(torch.get_default_dtype())
        # NOTE: a tensor whose elements are all equal has zero std, so
        # the division yields non-finite values.
        return (img - img.mean()) / img.std()
    raise ValueError(
        f"Invalid arg={normalize_data_method} passed to `normalize_data_method`"
    )
def adjust_brightness_transform(img, p, delta):
    """
    Function equivalent to `tf.image.adjust_brightness`,
    but executed probabilistically.

    :params img: Input torch.Tensor of any shape
    :params p: Probability, in [0, 1], with which the brightness
        shift is applied (p=0 never applies it, p=1 always does)
    :params delta: Float value representing the value
        by which img Tensor is increased or decreased.
    :returns `img + delta` when the shift is applied, otherwise `img`
        itself (unchanged)
    """
    # torch.rand draws uniformly from [0, 1), so `draw < p` is true with
    # probability p. (Comparing with `>` would apply the shift with
    # probability 1 - p instead.)
    if (torch.rand(1) < p).item():
        img = torch.add(img, delta)
    return img
def rotation_90_transform(img, num_rotations):
    """
    Function equivalent to `tf.image.rot90`
    Rotates img in counter clockwise direction

    :params img: torch.Tensor of shape (H, W), (C, H, W) or, more
        generally, (..., H, W)
    :params num_rotations: int value representing
        number of counter clock-wise rotations of img
    :returns torch.Tensor rotated by 90 degrees `num_rotations` times
        in the plane of its last two dimensions
    """
    # Always rotate in the (H, W) plane, i.e. the last two dimensions.
    # For (H, W) and (C, H, W) inputs this is identical to using dims
    # [0, 1] and [1, 2] respectively, and it keeps batched inputs from
    # being rotated over their leading (batch/channel) dimensions.
    return torch.rot90(img, k=num_rotations, dims=[-2, -1])
def resize_image_with_crop_or_pad_transform(img, target_height, target_width):
    """
    Function equivalent to `tf.image.resize_with_crop_or_pad`

    Each of the last two dimensions of `img` is center-cropped when it
    is larger than its target and padded (with the default fill of
    `torch.nn.functional.pad`) when it is smaller.

    :params img: torch.Tensor of shape (C, H, W) or (H, W)
    :params target_height: int value representing output image height
    :params target_width: int value representing output image width
    :returns torch.Tensor whose last two dimensions are
        (target_height, target_width); leading dimensions are unchanged
    """
    img_height, img_width = img.shape[-2], img.shape[-1]

    # Step 1: crop. CenterCrop pads when the requested size exceeds the
    # image, so clamp the crop size to the image size; padding is
    # handled explicitly in step 2.
    # NOTE(review): CenterCrop's offset rounding for odd size differences
    # may not match TF's floor-based offset -- confirm if exact parity
    # with TF is required.
    crop_size = (
        min(img_height, target_height),
        min(img_width, target_width),
    )
    cropped_img = transforms.CenterCrop(crop_size)(img)

    # Step 2: pad. After cropping neither dimension exceeds its target,
    # so both totals are >= 0. When a total is odd, the extra pixel goes
    # to the right / bottom.
    total_width_pad = target_width - cropped_img.shape[-1]
    total_height_pad = target_height - cropped_img.shape[-2]
    left_pad = total_width_pad // 2
    top_pad = total_height_pad // 2
    # Pad order for the last two dims is (left, right, top, bottom).
    pad = [
        left_pad,
        total_width_pad - left_pad,
        top_pad,
        total_height_pad - top_pad,
    ]
    padded_img = torch.nn.functional.pad(cropped_img, pad)

    # Sanity check on the result (not input validation).
    assert padded_img.shape[-1] == target_width
    assert padded_img.shape[-2] == target_height
    return padded_img
def tile_image_transform(img, target_height, target_width):
    """
    Function to tile image to target_height and target_width
    by repeating it periodically along the height and then the width.

    If target_height <= image_height: image is not tiled in this dimension.
    If target_width <= image_width: image is not tiled in this dimension.
    (The image is never cropped, so in those cases the corresponding
    output dimension keeps the input size.)

    :params img: input torch.Tensor of shape (C, H, W)
    :params target_height: int value representing output tiled image height
    :params target_width: int value representing output tiled image width
    :returns torch.Tensor of shape
        (C, max(H, target_height), max(W, target_width)); `img` itself
        is returned when no tiling is needed in either dimension
    """
    assert len(img.shape) == 3
    tgt_img_shape = [img.shape[0], target_height, target_width]

    def _get_tiled_image(img, tgt_img_shape, axis):
        """
        Repeat `img` along `axis` until it has tgt_img_shape[axis] entries
        """
        size = img.shape[axis]
        target = tgt_img_shape[axis]
        if target <= size:
            # No tiling since image already satisfies requirement
            return img
        # `target` is covered by `full_copies` whole images followed by
        # the first `remainder` rows/columns of one more copy.
        full_copies, remainder = divmod(target, size)
        pieces = [img] * full_copies
        if remainder > 0:
            pieces.append(img.narrow(axis, 0, remainder))
        # A single concatenation avoids re-copying the growing result
        # on every repetition.
        return torch.cat(pieces, dim=axis)

    v_tiled_img = _get_tiled_image(img, tgt_img_shape=tgt_img_shape, axis=1)
    tiled_img = _get_tiled_image(
        v_tiled_img, tgt_img_shape=tgt_img_shape, axis=2
    )
    return tiled_img