# TODO: think about the following two. They do weird things.
# - torch.nn.utils.clip_grad (but it should always be fp32 anyway)
# - torch.nn.utils.weight_norm

# Notes:
# F.instance_norm uses batch_norm internally, which correctly handles
#   fp16 in/out with fp32 weights. So we shouldn't do anything for
#   either of these.
# F.normalize calls `input.norm()` internally, so it's redundant, but
#   kept here in case the implementation changes.
# F.cosine_similarity is the same: it calls `x.norm()` internally.

import torch.nn.functional

MODULE = torch.nn.functional

# Functions whose floating-point tensor inputs are cast to fp16 before the call.
FP16_FUNCS = [
    'conv1d',
    'conv2d',
    'conv3d',
    'conv_transpose1d',
    'conv_transpose2d',
    'conv_transpose3d',
    'conv_tbc',  # Undocumented / maybe new?
    'linear',
]

# Functions whose floating-point tensor inputs are cast to fp32 before the call.
FP32_FUNCS = [
    # Interpolation/Upsampling  TODO: Remove for 1.2
    'interpolate',
    'grid_sample',

    # Pointwise
    'softplus',
    'softmin',
    'log_softmax',
    'softmax',
    'gelu',

    # Normalization
    'layer_norm',
    'group_norm',
    'local_response_norm',
    'normalize',
    'cosine_similarity',

    # Loss functions
    # TODO: which of these can be fp16?
    'poisson_nll_loss',
    'cosine_embedding_loss',
    'cross_entropy',
    'hinge_embedding_loss',
    'kl_div',
    'l1_loss',
    'mse_loss',
    'margin_ranking_loss',
    'multilabel_margin_loss',
    'multilabel_soft_margin_loss',
    'multi_margin_loss',
    'nll_loss',
    'binary_cross_entropy_with_logits',
    'smooth_l1_loss',
    'soft_margin_loss',
    'triplet_margin_loss',
    'ctc_loss',
]

# Functions amp refuses to patch by default; each entry pairs the function
# name with the message shown to the user unless `allow_banned=True` is
# passed to `amp.init()`.
BANNED_FUNCS = [
    ('binary_cross_entropy',
     ("\namp does not work out-of-the-box with `F.binary_cross_entropy` or `torch.nn.BCELoss`. "
      "It requires that the output of the previous function already be a FloatTensor.\n\n"
      "Most models have a Sigmoid right before BCELoss. In that case, you can use\n"
      "    torch.nn.BCEWithLogitsLoss\nto combine Sigmoid+BCELoss into a single layer "
      "that is compatible with amp.\nAnother option is to add\n"
      "    amp.register_float_function(torch, 'sigmoid')\nbefore calling `amp.init()`.\n"
      "If you _really_ know what you are doing, you can disable this warning by passing "
      "allow_banned=True to `amp.init()`.")),
]
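
# ---------------------------------------------------------------------------
# Illustrative sketch only, not the actual amp patching code: one way a
# caster could consume the lists above is to wrap each listed function so
# that floating-point tensor arguments are cast to the target dtype before
# the call. The helper names `_cast_tensors` and `_wrap` are hypothetical.
# Guarded under __main__ so importing this module stays side-effect free.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import functools

    import torch

    def _cast_tensors(args, dtype):
        # Cast only floating-point tensors; leave ints, bools, etc. untouched.
        return tuple(a.to(dtype) if torch.is_tensor(a) and a.is_floating_point()
                     else a for a in args)

    def _wrap(fn, dtype):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            return fn(*_cast_tensors(args, dtype), **kwargs)
        return wrapper

    # Demo with a function from FP32_FUNCS: the wrapped softmax promotes a
    # half-precision input to fp32 before calling torch.nn.functional.softmax.
    fp32_softmax = _wrap(getattr(MODULE, 'softmax'), torch.float32)
    x = torch.randn(4, 8).half()
    print(fp32_softmax(x, dim=-1).dtype)  # torch.float32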