cerebras.modelzoo.losses.DPOLoss.DPOLoss

class cerebras.modelzoo.losses.DPOLoss.DPOLoss(*args, **kwargs)

Bases: torch.nn.Module

Direct Preference Optimization (DPO) loss.

Parameters

beta – Temperature parameter for the DPO loss, typically in the range 0.1 to 0.5. As beta approaches 0, the reference model is ignored.

reference_free (bool) – If True, ignore the provided reference model and implicitly use a reference model that assigns equal probability to all responses.
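To illustrate how beta and reference_free enter the objective, the following is a minimal, framework-free sketch of the standard sigmoid DPO loss for a single (chosen, rejected) pair. It assumes the inputs are summed log-probabilities of each full response. The function name dpo_loss and its arguments are hypothetical and are not the signature of DPOLoss.forward; the Cerebras module operates on batched tensors and may differ in details.

```python
import math


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp,
             beta=0.1, reference_free=False):
    """Sigmoid DPO loss for one (chosen, rejected) pair.

    Hypothetical helper for illustration only; not the Cerebras DPOLoss API.
    Each argument is the summed log-probability of a whole response.
    """
    # How much more the policy prefers the chosen response over the rejected one.
    pi_logratio = policy_chosen_logp - policy_rejected_logp
    if reference_free:
        # A reference model that assigns equal probability to all responses
        # has a chosen-vs-rejected log-ratio of exactly 0.
        ref_logratio = 0.0
    else:
        ref_logratio = ref_chosen_logp - ref_rejected_logp
    # beta scales the difference of log-ratios (the implicit reward margin).
    logits = beta * (pi_logratio - ref_logratio)
    # loss = -log(sigmoid(logits)) = softplus(-logits), computed stably.
    return max(-logits, 0.0) + math.log1p(math.exp(-abs(logits)))
```

When the policy and reference log-ratios agree, the margin is 0 and the loss is log 2; the loss falls below log 2 as the policy prefers the chosen response more strongly than the reference does.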

Methods

forward