# Source code for vision.pytorch.input.classification.mixup

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
    Mixup and CutMix
    
    This is borrowed from the PyTorch repo:
    https://github.com/pytorch/vision/blob/main/references/classification/transforms.py
"""

import math

import torch
from torchvision.transforms import functional as F


class RandomMixup(torch.nn.Module):
    """Randomly apply Mixup to the provided batch and targets.

    The class implements the data augmentations as described in the paper
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Must be positive. Default value is 1.0.
        inplace (bool): boolean to make this transform inplace (the input
            ``batch`` is modified). Default set to False.
    """

    def __init__(
        self, num_classes, p=0.5, alpha=1.0, inplace=False,
    ):
        super().__init__()
        if num_classes < 1:
            raise ValueError(
                f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
            )
        if alpha <= 0:
            # Beta(alpha, alpha) needs alpha > 0, so negative values are
            # rejected as well as zero.
            raise ValueError(f"Alpha param must be positive. Got alpha={alpha}")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch, target):
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer (int64) tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: the (possibly) mixed batch, and the one-hot
            encoded (possibly mixed) targets of size (B, num_classes) in the
            dtype of ``batch``.

        Raises:
            ValueError: if ``batch`` is not 4-D or ``target`` is not 1-D.
            TypeError: if ``batch`` is not floating point or ``target`` is not int64.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(
                f"Batch dtype should be a float tensor. Got {batch.dtype}."
            )
        if target.dtype != torch.int64:
            raise TypeError(
                f"Target dtype should be torch.int64. Got {target.dtype}"
            )

        if not self.inplace:
            batch = batch.clone()

        # one_hot(...).to(...) always allocates a fresh tensor, so the caller's
        # target is never touched and does not need a defensive clone.
        target = torch.nn.functional.one_hot(
            target, num_classes=self.num_classes
        ).to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as on mixup paper, page 3.
        # The first component of a Dirichlet(alpha, alpha) sample is a
        # Beta(alpha, alpha) sample.
        lambda_param = float(
            torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]
        )
        batch_rolled.mul_(1.0 - lambda_param)
        batch.mul_(lambda_param).add_(batch_rolled)

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self):
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s
class RandomCutmix(torch.nn.Module):
    """Randomly apply Cutmix to the provided batch and targets.

    The class implements the data augmentations as described in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Must be positive. Default value is 1.0.
        inplace (bool): boolean to make this transform inplace (the input
            ``batch`` is modified). Default set to False.
    """

    def __init__(
        self, num_classes, p=0.5, alpha=1.0, inplace=False,
    ):
        super().__init__()
        if num_classes < 1:
            raise ValueError(
                f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
            )
        if alpha <= 0:
            # Beta(alpha, alpha) needs alpha > 0, so negative values are
            # rejected as well as zero.
            raise ValueError(f"Alpha param must be positive. Got alpha={alpha}")

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch, target):
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer (int64) tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: the (possibly) cut-mixed batch, and the
            one-hot encoded (possibly mixed) targets of size (B, num_classes)
            in the dtype of ``batch``. The target mixing weight equals the
            fraction of pixels actually pasted.

        Raises:
            ValueError: if ``batch`` is not 4-D or ``target`` is not 1-D.
            TypeError: if ``batch`` is not floating point or ``target`` is not int64.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(
                f"Batch dtype should be a float tensor. Got {batch.dtype}."
            )
        if target.dtype != torch.int64:
            raise TypeError(
                f"Target dtype should be torch.int64. Got {target.dtype}"
            )

        if not self.inplace:
            batch = batch.clone()

        # one_hot(...).to(...) always allocates a fresh tensor, so the caller's
        # target is never touched and does not need a defensive clone.
        target = torch.nn.functional.one_hot(
            target, num_classes=self.num_classes
        ).to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
        lambda_param = float(
            torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]
        )
        # Read the spatial size straight from the validated (B, C, H, W)
        # layout. torchvision's get_image_size returns [width, height], and
        # unpacking it as (H, W) swaps the axes for non-square images.
        H, W = batch.shape[-2:]

        r_x = torch.randint(W, (1,))
        r_y = torch.randint(H, (1,))

        r = 0.5 * math.sqrt(1.0 - lambda_param)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        # Clamp the box to the image so (x2 - x1) * (y2 - y1) is exactly the
        # number of pixels the slice assignment below replaces.
        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))

        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
        # Re-derive lambda from the actual (clamped) box area.
        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self):
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s