import torch
import numpy as np
from torch.nn.modules.batchnorm import _BatchNorm

import bnp


class bn_NHWC_impl(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, is_train, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):
        if is_train:
            ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv)
            ctx.epsilon = epsilon
            ctx.momentum = mom
            ctx.ret_cta = ret_cta
            ctx.fuse_relu = fuse_relu
            ctx.my_data = my_data
            ctx.pair_data = pair_data
            ctx.magic = magic
            ctx.pair_data2 = pair_data2
            ctx.pair_data3 = pair_data3
            ctx.bn_group = bn_group
            ctx.bwd_occup = bwd_occup
            ctx.bwd_grid_x = bwd_grid_x
            ctx.multi_stream = multi_stream

            res = bnp.bn_fwd_nhwc(x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu,
                                  my_data, pair_data, pair_data2, pair_data3, bn_group, magic,
                                  fwd_occup, fwd_grid_x, multi_stream)
            return res
        else:
            return bnp.bn_fwd_eval_nhwc(x, s, b, rm, riv, ret_cta, bn_group, mom, epsilon, fuse_relu)

    @staticmethod
    def backward(ctx, grad_y):
        x, s, b, rm, riv, mini_m, mini_riv = ctx.saved_variables
        epsilon = ctx.epsilon
        mom = ctx.momentum
        ret_cta = ctx.ret_cta
        fuse_relu = ctx.fuse_relu
        my_data = ctx.my_data
        pair_data = ctx.pair_data
        magic = ctx.magic
        pair_data2 = ctx.pair_data2
        pair_data3 = ctx.pair_data3
        bn_group = ctx.bn_group
        bwd_occup = ctx.bwd_occup
        bwd_grid_x = ctx.bwd_grid_x
        multi_stream = ctx.multi_stream

        dx, dscale, dbias = bnp.bn_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu,
                                            my_data, pair_data, pair_data2, pair_data3, bn_group, magic,
                                            bwd_occup, bwd_grid_x, multi_stream)

        # backward() must return one value per forward() argument (excluding ctx);
        # only x, s and b receive gradients, the remaining 20 arguments get None
        return (dx, dscale, dbias) + (None,) * 20


class bn_addrelu_NHWC_impl(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, grid_dim_y, ret_cta, mom, epsilon, is_train, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):
        if is_train:
            bitmask = torch.cuda.IntTensor(((x.numel() + 31) // 32) * 2 * grid_dim_y)
            ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask)
            ctx.epsilon = epsilon
            ctx.momentum = mom
            ctx.ret_cta = ret_cta
            ctx.my_data = my_data
            ctx.pair_data = pair_data
            ctx.magic = magic
            ctx.pair_data2 = pair_data2
            ctx.pair_data3 = pair_data3
            ctx.bn_group = bn_group
            ctx.bwd_occup = bwd_occup
            ctx.bwd_grid_x = bwd_grid_x
            ctx.multi_stream = multi_stream

            res = bnp.bn_addrelu_fwd_nhwc(x, z, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon,
                                          my_data, pair_data, pair_data2, pair_data3, bn_group, magic,
                                          fwd_occup, fwd_grid_x, multi_stream)
            return res
        else:
            return bnp.bn_addrelu_fwd_eval_nhwc(x, z, s, b, rm, riv, ret_cta, bn_group, mom, epsilon)

    @staticmethod
    def backward(ctx, grad_y):
        x, s, b, rm, riv, mini_m, mini_riv, bitmask = ctx.saved_variables
        epsilon = ctx.epsilon
        mom = ctx.momentum
        ret_cta = ctx.ret_cta
        my_data = ctx.my_data
        pair_data = ctx.pair_data
        magic = ctx.magic
        pair_data2 = ctx.pair_data2
        pair_data3 = ctx.pair_data3
        bn_group = ctx.bn_group
        bwd_occup = ctx.bwd_occup
        bwd_grid_x = ctx.bwd_grid_x
        multi_stream = ctx.multi_stream

        dx, dz, dscale, dbias = bnp.bn_addrelu_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon,
                                                        my_data, pair_data, pair_data2, pair_data3, bn_group, magic,
                                                        bwd_occup, bwd_grid_x, multi_stream)

        # backward() must return one value per forward() argument (excluding ctx);
        # only x, z, s and b receive gradients, the remaining 20 arguments get None
        return (dx, dz, dscale, dbias) + (None,) * 20
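# --- Illustrative reference (not part of this module's API) --------------------
# A minimal pure-PyTorch sketch of the per-channel normalization the fused NHWC
# kernels above compute during training, assuming x is laid out as (N, H, W, C)
# with channels last. The helper name `_reference_bn_nhwc` is hypothetical; the
# real kernels additionally update running statistics, optionally fuse a ReLU
# (and residual add), and can synchronize statistics across a bn_group.
def _reference_bn_nhwc(x, weight, bias, eps):
    mean = x.mean(dim=(0, 1, 2))                # per-channel mean over N, H, W
    var = x.var(dim=(0, 1, 2), unbiased=False)  # biased variance, as used in training
    return (x - mean) * torch.rsqrt(var + eps) * weight + bias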
class BatchNorm2d_NHWC(_BatchNorm):
    # if BatchNorm2d_NHWC is used simultaneously from multiple CUDA streams, set multi_stream to True
    def __init__(self, num_features, fuse_relu=False, bn_group=1, max_cta_per_sm=2, cta_launch_margin=12, multi_stream=False):
        super(BatchNorm2d_NHWC, self).__init__(num_features)

        self.fuse_relu = fuse_relu
        self.multi_stream = multi_stream

        self.minibatch_mean = torch.cuda.FloatTensor(num_features)
        self.minibatch_riv = torch.cuda.FloatTensor(num_features)

        # default to distributed bn disabled
        self.bn_group = bn_group
        self.max_cta_per_sm = max_cta_per_sm        # used only in training fwd and bwd
        self.cta_launch_margin = cta_launch_margin  # used only in training fwd and bwd
        self.my_data = None
        self.pair_data = None
        self.pair_data2 = None
        self.pair_data3 = None
        self.local_rank = 0
        self.magic = torch.IntTensor([0])

        # calculate CTA-per-SM occupancies
        assert(max_cta_per_sm > 0)  # won't be able to do much with 0 CTAs :)
        self.fwd_occupancy = min(bnp.bn_fwd_nhwc_occupancy(), max_cta_per_sm)
        self.bwd_occupancy = min(bnp.bn_bwd_nhwc_occupancy(), max_cta_per_sm)
        self.addrelu_fwd_occupancy = min(bnp.bn_addrelu_fwd_nhwc_occupancy(), max_cta_per_sm)
        self.addrelu_bwd_occupancy = min(bnp.bn_addrelu_bwd_nhwc_occupancy(), max_cta_per_sm)

        # calculate grid dimensions based on occupancy numbers
        mp_count = torch.cuda.get_device_properties(None).multi_processor_count
        self.fwd_grid_dim_x = max(mp_count * self.fwd_occupancy - cta_launch_margin, 1)
        self.bwd_grid_dim_x = max(mp_count * self.bwd_occupancy - cta_launch_margin, 1)
        self.addrelu_fwd_grid_dim_x = max(mp_count * self.addrelu_fwd_occupancy - cta_launch_margin, 1)
        self.addrelu_bwd_grid_dim_x = max(mp_count * self.addrelu_bwd_occupancy - cta_launch_margin, 1)
        self.grid_dim_y = (num_features + 63) // 64

        # allocate scratch space used by the implementation
        # TODO: this scratch space is not supposed to be exposed to user code. It only needs one-time
        # initialization and the same buffer could be reused in future iterations. It is currently exposed
        # here, instead of requesting a new buffer from the cache allocator, to avoid unnecessary
        # re-initialization in future iterations.
        self.ret_cta = torch.cuda.ByteTensor(8192).fill_(0)

        # FIXME: turn pair handles into an array
        if bn_group > 1:
            local_rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()
            assert(world_size >= bn_group)
            assert(world_size % bn_group == 0)

            # number of pairwise exchange steps needed to cover the group (log2(bn_group))
            bn_sync_steps = 1
            if (bn_group == 4):
                bn_sync_steps = 2
            if (bn_group == 8):
                bn_sync_steps = 3

            self.ipc_buffer = torch.cuda.ByteTensor(bnp.get_buffer_size(bn_sync_steps))
            self.my_data = bnp.get_data_ptr(self.ipc_buffer)
            # we are walking on very thin ice here by utilizing internal `_share_cuda_()`
            self.storage = self.ipc_buffer.storage()
            self.share_cuda = self.storage._share_cuda_()
            internal_cuda_mem = self.share_cuda
            # internal_cuda_mem[1]: ipc_mem_handle
            my_handle = torch.cuda.ByteTensor(np.frombuffer(internal_cuda_mem[1], dtype=np.uint8))
            # internal_cuda_mem[3]: offset
            my_offset = torch.cuda.IntTensor([internal_cuda_mem[3]])

            # gather every rank's IPC handle and offset so each rank can map its peers' buffers
            handles_all = torch.empty(world_size, my_handle.size(0), dtype=my_handle.dtype, device=my_handle.device)
            handles_l = list(handles_all.unbind(0))
            torch.distributed.all_gather(handles_l, my_handle)

            offsets_all = torch.empty(world_size, my_offset.size(0), dtype=my_offset.dtype, device=my_offset.device)
            offsets_l = list(offsets_all.unbind(0))
            torch.distributed.all_gather(offsets_l, my_offset)

            # whom do I actually care about? that would be local_rank XOR 1
            self.pair_handle = handles_l[local_rank ^ 1].cpu().contiguous()
            pair_offset = offsets_l[local_rank ^ 1].cpu()
            self.pair_data = bnp.get_remote_data_ptr(self.pair_handle, pair_offset)

            if bn_group > 2:
                self.pair_handle2 = handles_l[local_rank ^ 2].cpu().contiguous()
                pair_offset2 = offsets_l[local_rank ^ 2].cpu()
                self.pair_data2 = bnp.get_remote_data_ptr(self.pair_handle2, pair_offset2)

            if bn_group > 4:
                self.pair_handle3 = handles_l[local_rank ^ 4].cpu().contiguous()
                pair_offset3 = offsets_l[local_rank ^ 4].cpu()
                self.pair_data3 = bnp.get_remote_data_ptr(self.pair_handle3, pair_offset3)

            # FIXME: get magic value into C code and eliminate from here
            self.magic = torch.IntTensor([2])

            self.local_rank = local_rank

    def forward(self, x, z=None):
        # z is an optional residual input; when given, the fused BN + add + ReLU path is taken
        if z is not None:
            assert(self.fuse_relu == True)
            return bn_addrelu_NHWC_impl.apply(x, z,
                                              self.weight, self.bias,
                                              self.running_mean, self.running_var,
                                              self.minibatch_mean, self.minibatch_riv,
                                              self.grid_dim_y, self.ret_cta,
                                              self.momentum, self.eps,
                                              self.training, self.bn_group,
                                              self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3,
                                              self.addrelu_fwd_occupancy, self.addrelu_fwd_grid_dim_x,
                                              self.addrelu_bwd_occupancy, self.addrelu_bwd_grid_dim_x,
                                              self.multi_stream)
        else:
            return bn_NHWC_impl.apply(x,
                                      self.weight, self.bias,
                                      self.running_mean, self.running_var,
                                      self.minibatch_mean, self.minibatch_riv,
                                      self.ret_cta,
                                      self.momentum, self.eps, self.fuse_relu,
                                      self.training, self.bn_group,
                                      self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3,
                                      self.fwd_occupancy, self.fwd_grid_dim_x,
                                      self.bwd_occupancy, self.bwd_grid_dim_x,
                                      self.multi_stream)

    def __del__(self):
        # release the peer buffer mappings opened in __init__
        if self.bn_group > 1:
            bnp.close_remote_data(self.pair_handle)
            if self.bn_group > 2:
                bnp.close_remote_data(self.pair_handle2)
                if self.bn_group > 4:
                    bnp.close_remote_data(self.pair_handle3)
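# --- Usage sketch (illustrative, not executed) ---------------------------------
# A minimal sketch of how BatchNorm2d_NHWC is typically constructed and called,
# assuming the bnp CUDA extension is built, a CUDA device is present and, for
# bn_group > 1, torch.distributed has already been initialized with one process
# per GPU. Shapes and dtypes below are hypothetical.
#
#   bn = BatchNorm2d_NHWC(256, fuse_relu=True)
#   x = torch.randn(32, 14, 14, 256, device="cuda", dtype=torch.half)  # NHWC layout
#   z = torch.randn_like(x)                                            # optional residual
#   y = bn(x)       # fused BN + ReLU
#   y = bn(x, z)    # fused BN + residual add + ReLU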