"""Utility functions used by both `pipeline_parallel` and `tensor_parallel`""" import torch from apex.transformer import parallel_state def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) def divide(numerator, denominator): """Ensure that numerator is divisible by the denominator and return the division value.""" ensure_divisibility(numerator, denominator) return numerator // denominator def split_tensor_into_1d_equal_chunks(tensor): """Break a tensor into equal 1D chunks.""" data = tensor.view(-1) partition_size = torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size() start_index = partition_size * parallel_state.get_tensor_model_parallel_rank() end_index = start_index + partition_size return data[start_index:end_index] def gather_split_1d_tensor(tensor): """Opposite of above function, gather values from model parallel ranks.""" world_size = parallel_state.get_tensor_model_parallel_world_size() numel = torch.numel(tensor) numel_gathered = world_size * numel gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False) chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)] torch.distributed.all_gather(chunks, tensor, group=parallel_state.get_tensor_model_parallel_group()) return gathered