# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelzoo.common.pytorch.run_utils import half_dtype_instance
class Dice:
    """Callable computing a Dice coefficient per sample and per channel.

    Calling an instance with ``(prediction, target)`` returns
    ``(2 * intersection + smooth_nr) / (target_sum + prediction_sum + smooth_dr)``
    where the sums run over every axis after the channel axis (axis 1), so
    the result keeps the first two axes of the inputs.
    """

    def __init__(
        self,
        num_classes: int,
        to_onehot_y: bool = True,
        to_onehot_x: bool = False,
        use_softmax: bool = True,
        use_argmax: bool = False,
        include_background: bool = False,
        input_shape=None,
        use_native_onehot: bool = True,
    ):
        """Store the configuration used by ``__call__``.

        Args:
            num_classes: Number of classes passed to ``to_one_hot``.
            to_onehot_y: One-hot encode the target before scoring.
            to_onehot_x: One-hot encode the prediction before scoring.
            use_softmax: Apply softmax over the channel axis of the
                prediction. Takes precedence over ``use_argmax``.
            use_argmax: Apply argmax over the channel axis of the prediction
                (only consulted when ``use_softmax`` is False).
            include_background: When False, channel 0 of both tensors is
                multiplied by zero before scoring.
            input_shape: Full shape of the (one-hot) tensors being scored,
                including the channel axis. Any sequence of ints (list,
                tuple, ``torch.Size``) is accepted. Required when
                ``include_background`` is False.
            use_native_onehot: Forwarded to ``to_one_hot``.

        Raises:
            ValueError: If ``include_background`` is False and no
                ``input_shape`` was supplied.
        """
        self.num_classes = num_classes
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.to_onehot_x = to_onehot_x
        self.use_softmax = use_softmax
        self.use_argmax = use_argmax
        self.use_native_onehot = use_native_onehot
        self.smooth_nr = 0.0
        self.smooth_dr = 1e-6
        self.input_shape = None
        # Built lazily on the first call, then reused for every later call.
        self.bg_mask = None
        if not self.include_background:
            if not input_shape:
                raise ValueError(
                    "must supply input shape when include_background=False"
                )
            # Normalize to a list: the mask builder concatenates list slices,
            # which would raise TypeError for a tuple or torch.Size.
            self.input_shape = list(input_shape)

    def _create_background_mask(self, device, dtype, ish, chanx):
        """Build a tensor of shape ``ish`` that is 0 in channel 0, 1 elsewhere.

        Args:
            device: Device on which the mask is allocated.
            dtype: Dtype of the mask.
            ish: Shape of the mask as a sequence of ints.
            chanx: Index of the channel axis within ``ish``.

        Returns:
            The mask tensor, wrapped by ``cm.make_constant`` when
            ``cm.use_cs()`` is true.
        """
        from modelzoo.common.pytorch import cb_model as cm

        dims = list(ish)
        leading, trailing = dims[:chanx], dims[chanx + 1 :]
        z_shape = leading + [1] + trailing  # e.g. [N,1,D,H,W]
        o_shape = leading + [dims[chanx] - 1] + trailing  # e.g. [N,C-1,D,H,W]
        zeros = torch.zeros(z_shape, device=device, dtype=dtype)
        ones = torch.ones(o_shape, device=device, dtype=dtype)
        # Concatenating along the channel axis zeroes the first channel only.
        weights = torch.cat((zeros, ones), chanx)
        if cm.use_cs():
            bg_mask = cm.make_constant(weights)
        else:
            bg_mask = weights.to(device)
        return bg_mask

    def __call__(self, prediction, target):
        """Return the Dice coefficient of ``prediction`` against ``target``.

        Args:
            prediction: Tensor whose axis 1 is the channel axis.
            target: Tensor without a channel axis; one is inserted at axis 1
                before any further processing.

        Returns:
            Tensor holding one coefficient per (axis-0, axis-1) entry; all
            later axes are summed out.

        Raises:
            AssertionError: If the background is excluded but the prediction
                has a single channel, or if the processed target and
                prediction shapes differ.
        """
        target = torch.unsqueeze(target, 1)
        channel_axis = 1
        reduce_axis = list(range(2, len(prediction.shape)))
        # Captured before softmax/argmax/one-hot can change the channel count.
        num_pred_ch = prediction.shape[channel_axis]

        if self.use_softmax:
            prediction = torch.softmax(prediction, dim=channel_axis)
        elif self.use_argmax:
            prediction = torch.argmax(prediction, dim=channel_axis)

        if self.to_onehot_y:
            target = to_one_hot(
                target, channel_axis, self.num_classes, self.use_native_onehot
            )
        if self.to_onehot_x:
            prediction = to_one_hot(
                prediction,
                channel_axis,
                self.num_classes,
                self.use_native_onehot,
            )

        if not self.include_background:
            if self.bg_mask is None:
                self.bg_mask = self._create_background_mask(
                    target.device,
                    prediction.dtype,
                    self.input_shape,
                    channel_axis,
                )
            # Kept as asserts so callers still observe AssertionError.
            assert (
                num_pred_ch > 1
            ), f"To exclude background the prediction needs more than one channel. Got {num_pred_ch}."
            target = target * self.bg_mask
            prediction = prediction * self.bg_mask

        assert (
            target.shape == prediction.shape
        ), f"Target and prediction shape do not match. Target: ({target.shape}), prediction: ({prediction.shape})."

        intersection = torch.sum(target * prediction, dim=reduce_axis)
        target_sum = torch.sum(target, dim=reduce_axis)
        prediction_sum = torch.sum(prediction, dim=reduce_axis)
        # smooth_dr keeps the ratio finite when both sums are zero.
        res = (2.0 * intersection + self.smooth_nr) / (
            target_sum + prediction_sum + self.smooth_dr
        )
        return res
def to_one_hot(array, channel_axis, num_classes, use_native_onehot):
    """One-hot encode integer class labels and move the class axis to axis 1.

    Args:
        array: Tensor of class indices. When it has five or more axes,
            ``channel_axis`` is squeezed first (a no-op unless that axis has
            size 1).
        channel_axis: Axis squeezed from ``array`` before encoding.
        num_classes: Size of the one-hot axis.
        use_native_onehot: Use ``F.one_hot`` when true; otherwise scatter
            ones into a zero tensor allocated with
            ``half_dtype_instance.half_dtype``.

    Returns:
        Float tensor with the one-hot axis permuted from last to second
        position. The fixed permutation requires the encoded tensor to have
        exactly five axes.
    """
    labels = array
    if labels.dim() >= 5:
        labels = torch.squeeze(labels, dim=channel_axis)

    class_ids = labels.long()
    if not use_native_onehot:
        encoded = torch.zeros(
            labels.shape + (num_classes,),
            device=labels.device,
            dtype=half_dtype_instance.half_dtype,
        )
        # Write a 1 at each label's class index along the new last axis.
        encoded = encoded.scatter_(-1, class_ids.unsqueeze(-1), 1.0)
    else:
        encoded = F.one_hot(class_ids, num_classes)

    # Class axis goes from last to second position.
    return encoded.float().permute(0, 4, 1, 2, 3)
class DiceCELoss(nn.Module):
    """Loss combining cross-entropy with a Dice term.

    The value returned by ``forward`` is
    ``wc * cross_entropy + wd * (1 - mean_correction * mean(dice))``.
    """

    def __init__(
        self, num_classes, input_shape, include_background, wc=0.5, wd=0.5,
    ):
        """Build the Dice scorer, the cross-entropy criterion and constants.

        Args:
            num_classes: Number of classes, forwarded to ``Dice``.
            input_shape: Forwarded to ``Dice``.
            include_background: Forwarded to ``Dice``; also selects the
                correction factor applied to the mean Dice value.
            wc: Weight of the cross-entropy term.
            wd: Weight of the Dice term.
        """
        super().__init__()
        self.dice = Dice(
            num_classes=num_classes,
            include_background=include_background,
            input_shape=input_shape,
        )
        self.cross_entropy = nn.CrossEntropyLoss()
        self.wc = wc
        self.wd = wd
        # With the background excluded the mean is rescaled by C / (C - 1);
        # presumably this compensates for the masked channel — TODO confirm.
        correction = (
            1.0 if include_background else num_classes / (num_classes - 1)
        )
        self.mean_correction = torch.tensor(correction, dtype=torch.float32)
        self.one_const = torch.tensor(1.0, dtype=torch.float32)

    def forward(self, outputs, labels):
        """Return the weighted cross-entropy plus Dice loss.

        Args:
            outputs: Model outputs, passed to both criteria unchanged.
            labels: Ground-truth labels, passed to both criteria unchanged.
        """
        ce_term = self.cross_entropy(outputs, labels)
        scaled_dice = self.mean_correction * torch.mean(
            self.dice(outputs, labels)
        )
        dice_term = self.one_const - scaled_dice
        return self.wc * ce_term + self.wd * dice_term
class DiceScore:
    """Evaluation metric: Dice coefficient averaged over the first axis."""

    def __init__(
        self,
        to_onehot_y: bool = True,
        to_onehot_x: bool = True,
        use_argmax: bool = False,  # argmax already done in model
        use_softmax: bool = False,
        include_background: bool = False,
        num_classes=None,
        input_shape=None,
    ):
        """Configure the wrapped ``Dice`` scorer.

        Args:
            to_onehot_y: One-hot encode the labels.
            to_onehot_x: One-hot encode the predictions.
            use_argmax: Apply argmax to the predictions inside ``Dice``.
            use_softmax: Apply softmax to the predictions inside ``Dice``.
            include_background: Keep channel 0 when scoring.
            num_classes: Number of classes. ``Dice`` cannot be built without
                it, so it must be supplied even though it is a trailing
                keyword here (kept last for backward compatibility).
            input_shape: Shape forwarded to ``Dice``; needed there whenever
                ``include_background`` is False.

        Raises:
            TypeError: If ``num_classes`` is not supplied.
            ValueError: Raised by ``Dice`` when ``include_background`` is
                False and ``input_shape`` is missing.
        """
        if num_classes is None:
            # Previously surfaced as an opaque "missing required argument"
            # TypeError from Dice.__init__; keep the type, clarify the cause.
            raise TypeError(
                "DiceScore requires the 'num_classes' keyword argument"
            )
        self.dice = Dice(
            num_classes=num_classes,
            to_onehot_y=to_onehot_y,
            to_onehot_x=to_onehot_x,
            use_softmax=use_softmax,
            use_argmax=use_argmax,
            include_background=include_background,
            input_shape=input_shape,
        )

    def __call__(self, labels=None, predictions=None, weights=None):
        """Return the Dice result averaged over axis 0.

        Args:
            labels: Ground-truth labels passed to ``Dice`` as the target.
            predictions: Predictions passed to ``Dice``.
            weights: Accepted but not used.
        """
        return torch.mean(self.dice(predictions, labels), dim=0)