Source code for cerebras.modelzoo.losses.dpr_loss

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torch import nn


class DPRLoss(nn.Module):
    def __init__(self, mutual_information, softmax_temperature):
        super(DPRLoss, self).__init__()
        self.loss_fn = nn.CrossEntropyLoss()
        self.mutual_information = mutual_information
        self.softmax_temperature = softmax_temperature

    def forward(self, q2c_scores, labels, c2q_scores=None, context_labels=None):
        '''
        Args:
            q2c_scores: question-to-context scores
                (batch_size, batch_size * num_context)
            labels: labels for question-to-context scores (batch_size, )

        Optional Args:
            c2q_scores: context-to-question scores (batch_size, batch_size)
            context_labels: labels for context-to-question scores (batch_size, )

        Please see the comment in dpr_model.py for more details.
        '''
        num_context = int(q2c_scores.shape[1] / q2c_scores.shape[0])
        # When mutual_information is True, an additional context-to-question
        # loss is added to the original question-to-context loss, which puts
        # the loss value on a different scale from the case where mutual
        # information is off. Therefore, we apply a distribution_factor,
        # based on the number of contexts, to keep the loss value at the
        # same scale.
        distribution_factor = 1 / (1 + num_context)
        if not self.mutual_information:
            loss = self.loss_fn(q2c_scores / self.softmax_temperature, labels)
        else:
            question_to_context_loss = self.loss_fn(
                q2c_scores / self.softmax_temperature, labels
            )
            context_to_question_loss = self.loss_fn(
                c2q_scores / self.softmax_temperature, context_labels
            )
            loss = (
                (1 - distribution_factor) * question_to_context_loss
                + distribution_factor * context_to_question_loss
            )
        return loss
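The snippet below is a minimal usage sketch, assuming the tensor shapes stated in the forward docstring. The batch size, number of contexts, temperature value, and the positive-index convention chosen for labels are illustrative assumptions; the real score tensors and label layout are produced by dpr_model.py.

import torch

from cerebras.modelzoo.losses.dpr_loss import DPRLoss

batch_size, num_context = 4, 2

# Question-to-context scores: each of the batch_size questions is scored
# against all batch_size * num_context contexts in the batch.
q2c_scores = torch.randn(batch_size, batch_size * num_context)
# Hypothetical positive-context indices, one per question; the actual
# label construction lives in dpr_model.py.
labels = torch.arange(batch_size) * num_context

# Context-to-question scores and labels, only needed when
# mutual_information=True.
c2q_scores = torch.randn(batch_size, batch_size)
context_labels = torch.arange(batch_size)

loss_fn = DPRLoss(mutual_information=True, softmax_temperature=1.0)
loss = loss_fn(q2c_scores, labels, c2q_scores, context_labels)
print(loss)  # scalar loss tensor

With mutual_information=False, only q2c_scores and labels are required, and the forward pass reduces to a single cross-entropy over the question-to-context scores.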