import torch
import torch.nn as nn
import torch.nn.functional as F

# https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py


class FocalLoss(nn.Module):
    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                         # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))    # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input, dim=1)   # log-probabilities over classes
        logpt = logpt.gather(1, target)       # log p of the target class
        logpt = logpt.view(-1)
        pt = logpt.detach().exp()             # p_t, treated as a constant in the modulation term

        if self.alpha is not None:
            if self.alpha.type() != input.type():
                self.alpha = self.alpha.type_as(input)
            at = self.alpha.gather(0, target.view(-1))  # per-class weight alpha_t
            logpt = logpt * at

        # Focal loss: -(1 - p_t)^gamma * log(p_t), optionally alpha-weighted.
        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average:
            return loss.mean()
        return loss.sum()


class BinaryFocalLoss(nn.Module):
    def __init__(self, weight=None, gamma=2., alpha=1, reduction='mean'):
        super(BinaryFocalLoss, self).__init__()
        self.weight = weight
        self.gamma = gamma
        self.reduction = reduction
        self.alpha = alpha

    def forward(self, inputs, targets):
        inputs = inputs.float()
        targets = targets.float()
        # Compute the element-wise BCE first so the focal modulation is applied
        # per element; the requested reduction happens afterwards.
        BCE_loss = F.binary_cross_entropy_with_logits(
            inputs, targets, weight=self.weight, reduction='none')
        pt = torch.exp(-BCE_loss)  # prevents nans when probability 0
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        if self.reduction == 'mean':
            return F_loss.mean()
        if self.reduction == 'sum':
            return F_loss.sum()
        return F_loss
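

# --- Usage sketch (not part of the original file): a minimal smoke test,
# assuming FocalLoss receives raw (unnormalized) logits of shape (N, C) with
# integer class targets, and BinaryFocalLoss receives raw logits with 0/1
# targets of the same shape. Shapes and hyperparameters below are illustrative.
if __name__ == "__main__":
    # Multi-class focal loss: 8 samples, 3 classes, per-class alpha weights.
    logits = torch.randn(8, 3)
    labels = torch.randint(0, 3, (8,))
    fl = FocalLoss(gamma=2, alpha=[0.25, 0.5, 0.25])
    print("FocalLoss:", fl(logits, labels).item())

    # Binary focal loss on raw logits with 0/1 targets.
    bin_logits = torch.randn(8)
    bin_labels = torch.randint(0, 2, (8,)).float()
    bfl = BinaryFocalLoss(gamma=2.0, alpha=0.25)
    print("BinaryFocalLoss:", bfl(bin_logits, bin_labels).item())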