from .. import compat
from . import torch_overrides

import importlib

import torch

# if compat.variable_is_tensor() and not compat.tensor_is_variable():
MODULE = torch.Tensor
# else:
#     MODULE = torch.autograd.Variable

FP16_FUNCS = compat.filter_attrs(MODULE, [
    '__matmul__',
])

FP32_FUNCS = compat.filter_attrs(MODULE, [
    '__ipow__',
    '__pow__',
    '__rpow__',

    # Cast to fp32 before transfer to CPU
    'cpu',
])

CASTS = compat.filter_attrs(MODULE, [
    '__add__',
    '__div__',
    '__eq__',
    '__ge__',
    '__gt__',
    '__iadd__',
    '__idiv__',
    '__imul__',
    '__isub__',
    '__itruediv__',
    '__le__',
    '__lt__',
    '__mul__',
    '__ne__',
    '__radd__',
    '__rdiv__',
    '__rmul__',
    '__rsub__',
    '__rtruediv__',
    '__sub__',
    '__truediv__',
])

# None of these, but here to make code cleaner.
SEQUENCE_CASTS = []

# We need to grab all the methods from torch_overrides and add them to
# the Tensor lists as well, as almost all methods are duplicated
# between `torch` and `torch.Tensor` (and check with `hasattr`,
# because a few random ones aren't defined on Tensor)
_self_mod = importlib.import_module(__name__)
for attrname in ['FP16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:
    lst = getattr(_self_mod, attrname)
    for fn in getattr(torch_overrides, attrname):
        if hasattr(MODULE, fn):
            lst.append(fn)
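
# ---------------------------------------------------------------------------
# Illustrative sanity check (a sketch, not part of the list definitions
# above): assuming `compat.filter_attrs` keeps only the names that actually
# exist on MODULE, every collected entry should resolve to a real
# torch.Tensor attribute after the merge loop. Running this file as a module
# (e.g. `python -m <package>.tensor_overrides`) prints the list sizes.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    for _name in FP16_FUNCS + FP32_FUNCS + CASTS + SEQUENCE_CASTS:
        # Each entry was either filtered through filter_attrs or guarded by
        # hasattr in the loop above, so this assertion should never fire.
        assert hasattr(MODULE, _name), _name
    print('fp16:', len(FP16_FUNCS), 'fp32:', len(FP32_FUNCS),
          'casts:', len(CASTS), 'sequence casts:', len(SEQUENCE_CASTS))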