import torch
import torch.nn.functional as F


def cross_entropy(input, target, weight=None, reduction='mean', ignore_index=255):
    """Pixel-wise softmax cross-entropy loss for dense prediction.

    If the logits' spatial size differs from the target's, the logits are
    bilinearly resized (align_corners=True) to the target's size first. This
    lets predictions made at a different resolution be scored against the labels.

    :param input: torch.Tensor of raw logits, N*C*H*W
    :param target: torch.Tensor of class indices, N*1*H*W or N*H*W; cast to long
    :param weight: optional torch.Tensor of per-class weights, shape C
    :param reduction: 'mean' | 'sum' | 'none', forwarded to F.cross_entropy
    :param ignore_index: label value excluded from the loss (default 255)
    :return: torch.Tensor -- a scalar for 'mean'/'sum', per-pixel N*H*W for 'none'
    """
    target = target.long()
    if target.dim() == 4:
        # Drop the singleton channel axis: F.cross_entropy expects N*H*W labels.
        target = torch.squeeze(target, dim=1)
    # Compare the full spatial size, not only the width. A height-only
    # mismatch would otherwise reach F.cross_entropy and raise.
    if input.shape[-2:] != target.shape[-2:]:
        input = F.interpolate(input, size=target.shape[-2:], mode='bilinear', align_corners=True)
    return F.cross_entropy(input=input, target=target, weight=weight,
                           ignore_index=ignore_index, reduction=reduction)