import torch

from .. import utils

MODULE = torch

FP16_FUNCS = [
    # Low-level functions wrapped by torch.nn layers. The wrapper layers
    # contain the weights, which are then passed in as a parameter to
    # these functions.
    'conv1d',
    'conv2d',
    'conv3d',
    'conv_transpose1d',
    'conv_transpose2d',
    'conv_transpose3d',
    'conv_tbc',
    'prelu',

    # BLAS
    'addmm',
    'addmv',
    'addr',
    'matmul',
    'mm',
    'mv',
]

FP32_FUNCS = [
    # Pointwise
    'acos',
    'asin',
    'cosh',
    'erfinv',
    'exp',
    'expm1',
    'log',
    'log10',
    'log2',
    'reciprocal',
    'rsqrt',
    'sinh',
    'tan',

    # Other math
    'pow',

    # Reduction
    'cumprod',
    'cumsum',
    'dist',
    # 'mean',  # conditionally appended below, depending on the torch version
    'norm',
    'prod',
    'std',
    'sum',
    'var',

    # Misc
    'renorm'
]

# Before torch 1.1, mean must be blacklisted. Compare (major, minor) as an
# integer tuple; a float comparison misorders versions such as '1.10'.
version_strings = torch.__version__.split('.')
version_tuple = (int(version_strings[0]), int(version_strings[1]))
if version_tuple < (1, 1):
    FP32_FUNCS.append('mean')

# Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We check
# the CUDA version: if it is at least 9.1, put the bmm functions on the
# FP16 list; otherwise put them on the FP32 list.
_bmms = ['addbmm',
         'baddbmm',
         'bmm']

if utils.is_cuda_enabled():
    # workaround https://github.com/facebookresearch/maskrcnn-benchmark/issues/802
    if utils.get_cuda_version() >= (9, 1, 0):
        FP16_FUNCS.extend(_bmms)
    else:
        FP32_FUNCS.extend(_bmms)

# Multi-tensor fns that may need type promotion
CASTS = [
    # Multi-tensor math
    'addcdiv',
    'addcmul',
    'atan2',
    'cross',
    'bilinear',
    'dot',

    # Element-wise _or_ tensor-wise math
    'add',
    'div',
    'mul',

    # Comparison
    'eq',
    'equal',
    'ge',
    'gt',
    'le',
    'lt',
    'ne'
]

# Functions that take sequence arguments. We need to inspect the whole
# sequence and cast to the widest type.
SEQUENCE_CASTS = [
    'cat',
    'stack'
]
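
# Illustrative sketch (not part of this module): one way a patching layer
# could consume the lists above is to wrap each listed function on MODULE
# so its floating-point tensor arguments are cast before dispatch. The
# helpers below (_cast_tensor_args, _make_cast_wrapper) are hypothetical
# stand-ins for the real patching machinery, not its actual implementation.
def _cast_tensor_args(args, dtype):
    # Cast only floating-point tensors; pass every other argument through.
    return tuple(
        a.to(dtype) if torch.is_tensor(a) and a.is_floating_point() else a
        for a in args)


def _make_cast_wrapper(fn, dtype):
    # Return a wrapper that casts tensor args to `dtype` before calling fn.
    def wrapper(*args, **kwargs):
        return fn(*_cast_tensor_args(args, dtype), **kwargs)
    return wrapper
# A name in FP16_FUNCS such as 'mm' would then be patched roughly as
# torch.mm = _make_cast_wrapper(torch.mm, torch.float16), and names in
# FP32_FUNCS likewise with torch.float32.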
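
# Illustrative sketch (not part of this module): entries in CASTS rely on
# promoting mixed-precision operands to the widest dtype, and entries in
# SEQUENCE_CASTS require scanning the entire sequence for the widest dtype.
# A minimal, guarded demonstration using only the public torch API
# (torch.promote_types requires torch >= 1.5):
if __name__ == "__main__":
    a = torch.randn(4).half()  # float16
    b = torch.randn(4)         # float32

    # Pairwise promotion, as needed for CASTS entries such as 'add'.
    assert torch.promote_types(a.dtype, b.dtype) == torch.float32

    # Widest-type scan, as needed for SEQUENCE_CASTS entries such as 'cat'.
    seq = [a, b, torch.randn(4).half()]
    widest = seq[0].dtype
    for t in seq[1:]:
        widest = torch.promote_types(widest, t.dtype)
    out = torch.cat([t.to(widest) for t in seq])
    assert out.dtype == torch.float32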