from collections import defaultdict

import torch

from apex.transformer import parallel_state


class GradScaler(torch.cuda.amp.GradScaler):
    """
    Gradient scaler for model-parallel inf check. Inf values in gradients are checked across
    tensor-parallel ranks in (1) executing optimizer step and (2) gradient scaler update.
    """

    def __init__(
        self, init_scale=2.0 ** 16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True
    ):
        super().__init__(
            init_scale=init_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            enabled=enabled,
        )

    def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
        retval = None
        found_inf = torch.cuda.FloatTensor([sum(v.item() for v in optimizer_state["found_inf_per_device"].values())])

        # Update across all model parallel instances. MAX-reduce means that if any rank in the
        # model-parallel group found an inf/nan, every rank skips the optimizer step.
        torch.distributed.all_reduce(
            found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group()
        )

        if found_inf.item() == 0:
            retval = optimizer.step(*args, **kwargs)
        return retval

    def update(self, new_scale=None):
        """
        Updates the scale factor.

        If any optimizer steps were skipped, the scale is multiplied by ``backoff_factor``
        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
        the scale is multiplied by ``growth_factor`` to increase it.

        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
        used directly; it is used to fill GradScaler's internal scale tensor. So if
        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
        affect the scale GradScaler uses internally.)

        Args:
            new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.

        .. warning::
            :meth:`update` should only be called at the end of the iteration, after
            ``scaler.step(optimizer)`` has been invoked for all optimizers used this iteration.
        """
        if not self._enabled:
            return

        _scale, _growth_tracker = self._check_scale_growth_tracker("update")

        if new_scale is not None:
            # Accept a new user-defined scale.
            if isinstance(new_scale, float):
                self._scale.fill_(new_scale)  # type: ignore[union-attr]
            else:
                reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
                assert isinstance(new_scale, torch.cuda.FloatTensor), reason  # type: ignore[attr-defined]
                assert new_scale.numel() == 1, reason
                assert new_scale.requires_grad is False, reason
                self._scale.copy_(new_scale)  # type: ignore[union-attr]
        else:
            # Consume shared inf/nan data collected from optimizers to update the scale.
            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
            found_infs = [
                found_inf.to(device=_scale.device, non_blocking=True)
                for state in self._per_optimizer_states.values()
                for found_inf in state["found_inf_per_device"].values()
            ]

            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

            found_inf_combined = found_infs[0]
            # Update across all model parallel instances.
            torch.distributed.all_reduce(
                found_inf_combined, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group()
            )

            if len(found_infs) > 1:
                for i in range(1, len(found_infs)):
                    found_inf = found_infs[i]
                    # Update across all model parallel instances.
                    torch.distributed.all_reduce(
                        found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group()
                    )
                    found_inf_combined += found_inf

            torch._amp_update_scale_(
                _scale,
                _growth_tracker,
                found_inf_combined,
                self._growth_factor,
                self._backoff_factor,
                self._growth_interval,
            )

        # To prepare for next iteration, clear the data collected from optimizers this iteration.
        self._per_optimizer_states = defaultdict(torch.cuda.amp.grad_scaler._refresh_per_optimizer_state)
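

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). This assumes torch.distributed and apex
# model parallelism have already been initialized, e.g. via
# parallel_state.initialize_model_parallel(...). The names `model`,
# `optimizer`, `loss_fn`, and `loader` are hypothetical placeholders, not part
# of this module.
#
#     scaler = GradScaler(init_scale=2.0 ** 16, growth_interval=1000)
#     for batch, target in loader:
#         optimizer.zero_grad()
#         with torch.cuda.amp.autocast():
#             loss = loss_fn(model(batch), target)
#         scaler.scale(loss).backward()
#         scaler.step(optimizer)   # step is skipped on all ranks if any rank
#                                  # in the model-parallel group found inf/nan
#         scaler.update()          # backoff/growth decision is shared across
#                                  # the model-parallel group via all_reduce
# ---------------------------------------------------------------------------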