# flake8: noqa
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F


def SoftCE(predicted, target):
    """Soft cross entropy between logits `predicted` and a target probability distribution."""
    logprobs = torch.nn.functional.log_softmax(predicted, dim=1)
    return -(target * logprobs).sum() / predicted.shape[0]


class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with label smoothing regularizer.

    Reference:
    Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
    Equation: y = (1 - epsilon) * y + epsilon / K.

    Args:
        num_classes (int): number of classes.
        epsilon (float): smoothing weight.
    """

    def __init__(self, num_classes, epsilon=0.1):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1).cuda()

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes).
            targets: ground truth labels with shape (batch_size).
        """
        log_probs = self.logsoftmax(inputs)
        # one-hot encode the targets, then smooth them towards the uniform distribution
        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        loss = (- targets * log_probs).mean(0).sum()
        return loss


class SoftEntropy(nn.Module):
    """Cross entropy against the softmax of `targets` (soft pseudo-labels), detached from the graph."""

    def __init__(self):
        super(SoftEntropy, self).__init__()
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        loss = (- F.softmax(targets, dim=1).detach() * log_probs).mean(0).sum()
        return loss


class QuadrupletLoss(torch.nn.Module):
    """Quadruplet loss function.

    Builds on the triplet loss and takes four inputs: one anchor, one positive and
    two negative examples. The negative examples need not match the anchor, the
    positive, or each other.
    """

    def __init__(self, margin1=2.0, margin2=1.0):
        super(QuadrupletLoss, self).__init__()
        self.margin1 = margin1
        self.margin2 = margin2

    def forward(self, anchor, positive, negative1, negative2):
        squared_distance_pos = (anchor - positive).pow(2).sum(1)
        squared_distance_neg = (anchor - negative1).pow(2).sum(1)
        squared_distance_neg_b = (negative1 - negative2).pow(2).sum(1)

        quadruplet_loss = \
            F.relu(self.margin1 + squared_distance_pos - squared_distance_neg) \
            + F.relu(self.margin2 + squared_distance_pos - squared_distance_neg_b)

        return quadruplet_loss.mean()


def simCLR_loss(features, temperature=0.07, device='cuda'):
    batch_size, n_views = features.shape[0], features.shape[1]
    labels = torch.cat([torch.arange(batch_size) for i in range(n_views)], dim=0)
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
    labels = labels.to(device)
    # print(labels)

    if len(features.shape) > 3:
        features = features.contiguous().view(features.shape[0], features.shape[1], -1)
    features = torch.cat(torch.unbind(features, dim=1), dim=0)
    features = F.normalize(features, dim=1)

    similarity_matrix = torch.matmul(features, features.T)
    # assert similarity_matrix.shape == (
    #     self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)
    # assert similarity_matrix.shape == labels.shape

    # discard the main diagonal from both the labels and the similarity matrix
    mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
    # assert similarity_matrix.shape == labels.shape

    # select and combine multiple positives
    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)

    # select only the negatives
    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

    logits = torch.cat([positives, negatives], dim=1)
    labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)

    logits = logits / temperature
    return logits, labels
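

# Usage sketch (illustrative only, not part of the original training code): simCLR_loss
# returns InfoNCE-style logits and target indices, which are intended to be fed to a
# standard cross-entropy loss. Shapes below are made up for demonstration.
def _simclr_loss_example():
    feats = torch.randn(8, 2, 128)  # [batch_size, n_views, feat_dim], illustrative shapes
    logits, labels = simCLR_loss(feats, temperature=0.07, device='cpu')
    return F.cross_entropy(logits, labels)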


class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR."""

    def __init__(self, temperature=0.07, contrast_mode='all', base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to the SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...], '
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positives
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
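

# Usage sketch (illustrative only; shapes, class count, and normalization are assumptions):
# SupConLoss takes features of shape [bsz, n_views, dim] (assumed here to be L2-normalized,
# as is typical for this loss) and integer class labels of shape [bsz]. With labels=None and
# mask=None it reduces to the SimCLR objective, per the class docstring above.
def _supcon_loss_example():
    criterion = SupConLoss(temperature=0.07, contrast_mode='all')
    feats = F.normalize(torch.randn(8, 2, 128), dim=2)  # two augmented views per sample
    labels = torch.randint(0, 4, (8,))                  # made-up class ids
    return criterion(feats, labels=labels)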


class SoftSupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR. This variant
    accepts a real-valued (soft) contrastive mask instead of a hard 0/1 mask."""

    def __init__(self, temperature=0.07, contrast_mode='all', base_temperature=0.07):
        super(SoftSupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, soft_mask=None):
        """Compute loss for model. If both `labels` and `soft_mask` are None,
        it degenerates to the SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            soft_mask: contrastive mask of shape [bsz, bsz], where soft_mask_{i,j}
                weights how strongly sample j counts as a positive for sample i.
                Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...], '
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and soft_mask is not None:
            raise ValueError('Cannot define both `labels` and `soft_mask`')
        elif labels is None and soft_mask is None:
            soft_mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            soft_mask = torch.eq(labels, labels.T).float().to(device)
        else:
            soft_mask = soft_mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        soft_mask = soft_mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(soft_mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        soft_mask = soft_mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
        # log_prob = logits - torch.nn.functional.log_softmax(logits, dim=1)

        # compute mean of log-likelihood over positives
        mean_log_prob_pos = (soft_mask * log_prob).sum(1) / soft_mask.sum(1)
        # mean_log_prob_pos = (soft_mask * log_prob).sum(1) / logits.shape[0]

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
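

# Minimal smoke test (illustrative only; CPU tensors with made-up shapes). It exercises the
# losses that do not assume a CUDA device; CrossEntropyLabelSmooth is skipped because its
# constructor calls .cuda().
if __name__ == '__main__':
    preds = torch.randn(4, 10)
    soft_targets = F.softmax(torch.randn(4, 10), dim=1)
    print('SoftCE:', SoftCE(preds, soft_targets).item())
    print('SoftEntropy:', SoftEntropy()(preds, torch.randn(4, 10)).item())

    anchor, positive, negative1, negative2 = (torch.randn(4, 32) for _ in range(4))
    print('QuadrupletLoss:', QuadrupletLoss()(anchor, positive, negative1, negative2).item())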