# coding=utf-8 # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """GPT-2 model.""" import enum import math import contextlib import json import torch import torch.nn.functional as F import apex.transformer.utils from apex.normalization import FusedLayerNorm as LayerNorm from apex.transformer.functional import FusedScaleMaskSoftmax from apex.transformer import tensor_parallel from apex.transformer import parallel_state from apex.transformer.testing.global_vars import get_args from apex.transformer.enums import LayerType from apex.transformer.enums import AttnType from apex.transformer.enums import AttnMaskType _FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) _HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) _BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor) class ModelType(enum.Enum): encoder_or_decoder = 1 encoder_and_decoder = 2 ###### BIAS GELU FUSION/ NO AUTOGRAD ################ # 1/sqrt(2*pi)-> 0.3989423 # 1/sqrt(2) -> 0.70710678 # sqrt(2/pi) -> 0.79788456 # this function is tanh approximation of gelu # actual gelu is: # x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) @torch.jit.script def bias_gelu(bias, y): x = bias + y return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) # gradient of tanh approximation of gelu # gradient of actual gelu is: # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) @torch.jit.script def bias_gelu_back(g, bias, y): x = bias + y tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) return ff * g class MegatronModule(torch.nn.Module): """Megatron specific extensions of torch Module with support for pipelining.""" def __init__(self, share_word_embeddings=True): super(MegatronModule, self).__init__() self.share_word_embeddings = share_word_embeddings def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False): """Use this function to override the state dict for saving checkpoints.""" return self.state_dict(destination, prefix, keep_vars) def word_embeddings_weight(self): if not parallel_state.is_pipeline_last_stage(ignore_virtual=True) or \ parallel_state.get_pipeline_model_parallel_world_size() == 1: return self.language_model.embedding.word_embeddings.weight else: if not self.share_word_embeddings: raise Exception("word_embeddings_weight() called for last " "stage, but share_word_embeddings is false") return self.word_embeddings.weight def initialize_word_embeddings(self, init_method_normal): args = get_args() if not self.share_word_embeddings: raise Exception("initialize_word_embeddings() was called but " "share_word_embeddings is false") # This function just initializes the word embeddings in the final stage # when we are using pipeline parallelism. Nothing to do if we aren't # using pipeline parallelism. 
if args.pipeline_model_parallel_size == 1: return # Parameters are shared between the word embeddings layers, and the # heads at the end of the model. In a pipelined setup with more than # one stage, the initial embedding layer and the head are on different # workers, so we do the following: # 1. Create a second copy of word_embeddings on the last stage, with # initial parameters of 0.0. # 2. Do an all-reduce between the first and last stage to ensure that # the two copies of word_embeddings start off with the same # parameter values. # 3. In the training loop, before an all-reduce between the grads of # the two word_embeddings layers to ensure that every applied weight # update is the same on both stages. if parallel_state.is_pipeline_last_stage(): assert not parallel_state.is_pipeline_first_stage() self._word_embeddings_for_head_key = "word_embeddings_for_head" # set word_embeddings weights to 0 here, then copy first # stage's weights using all_reduce below. self.word_embeddings = tensor_parallel.VocabParallelEmbedding( args.padded_vocab_size, args.hidden_size, init_method=init_method_normal(args.init_method_std), use_cpu_initialization=args.use_cpu_initialization) self.word_embeddings.weight.data.fill_(0) self.word_embeddings.weight.shared = True # Zero out initial weights for decoder embedding. # NOTE: We don't currently support T5 with the interleaved schedule. if not parallel_state.is_pipeline_first_stage(ignore_virtual=True) and \ not parallel_state.is_pipeline_last_stage(ignore_virtual=True) and \ parallel_state.is_rank_in_embedding_group(): self.language_model.embedding.zero_parameters() # Ensure that first and last stages have the same initial parameter # values. if torch.distributed.is_initialized(): if parallel_state.is_rank_in_embedding_group(): torch.distributed.all_reduce(self.word_embeddings_weight().data, group=parallel_state.get_embedding_group()) # All-reduce other embeddings as well as necessary. The last stage # does not have these other embeddings, so just create placeholder # tensors of the right shape with all zeros. # NOTE: We don't currently support T5 with the interleaved schedule. if args.pipeline_model_parallel_split_rank is not None: # TODO: Support tokentype embedding. dimensions = (args.max_position_embeddings, args.hidden_size) if parallel_state.is_pipeline_last_stage(ignore_virtual=True): position_embeddings = torch.nn.Embedding(*dimensions).cuda() position_embeddings.weight.data.fill_(0) else: self.language_model.embedding.cuda() position_embeddings = self.language_model.embedding.position_embeddings torch.distributed.all_reduce(position_embeddings.weight.data, group=parallel_state.get_embedding_group()) else: print("WARNING! Distributed processes aren't initialized, so " "word embeddings in the last layer are not initialized. " "If you are just manipulating a model this is fine, but " "this needs to be handled manually. 
If you are training, "
              "something is definitely wrong.")


class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(bias, input)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, bias, input)
        return tmp, tmp


bias_gelu_impl = GeLUFunction.apply


def get_linear_layer(rows, columns, init_method):
    """Simple linear layer with weight initialization."""
    layer = torch.nn.Linear(rows, columns)
    init_method(layer.weight)
    with torch.no_grad():
        layer.bias.zero_()
    return layer


def attention_mask_func(attention_scores, attention_mask):
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores


@torch.jit.script
def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))


def openai_gelu(x):
    return gelu_impl(x)


# This is actually the Python equivalent of torch.nn.functional.gelu(), also with type hints for the ONNX exporter.
@torch.jit.script
def erf_gelu(x):
    return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))


def init_method_normal(sigma):
    """Init method based on N(0, sigma)."""

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

    return init_


def scaled_init_method_normal(sigma, num_layers):
    """Init method based on N(0, sigma/sqrt(2*num_layers))."""
    std = sigma / math.sqrt(2.0 * num_layers)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_


class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h hidden
    dimension, perform a nonlinear transformation, and project the state back
    into h hidden dimension.
    """

    def __init__(self, init_method, output_layer_init_method):
        super().__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True,
            use_cpu_initialization=args.use_cpu_initialization)

        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            use_cpu_initialization=args.use_cpu_initialization
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias


class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.
""" def __init__( self, init_method, output_layer_init_method, layer_number, attention_type=AttnType.self_attn, attn_mask_type=AttnMaskType.padding, ): super().__init__() args = get_args() self.fp16 = args.fp16 self.bf16 = args.bf16 self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32 if self.apply_query_key_layer_scaling: self.attention_softmax_in_fp32 = True self.layer_number = max(1, layer_number) self.attention_type = attention_type self.attn_mask_type = attn_mask_type self.params_dtype = args.params_dtype projection_size = args.kv_channels * args.num_attention_heads # Per attention head and per partition values. world_size = parallel_state.get_tensor_model_parallel_world_size() self.hidden_size_per_partition = apex.transformer.utils.divide(projection_size, world_size) self.hidden_size_per_attention_head = apex.transformer.utils.divide(projection_size, args.num_attention_heads) self.num_attention_heads_per_partition = apex.transformer.utils.divide(args.num_attention_heads, world_size) # Strided linear layer. if attention_type == AttnType.self_attn: self.query_key_value = tensor_parallel.ColumnParallelLinear( args.hidden_size, 3 * projection_size, gather_output=False, init_method=init_method, use_cpu_initialization=args.use_cpu_initialization) else: assert attention_type == AttnType.cross_attn self.query = tensor_parallel.ColumnParallelLinear( args.hidden_size, projection_size, gather_output=False, init_method=init_method, use_cpu_initialization=args.use_cpu_initialization) self.key_value = tensor_parallel.ColumnParallelLinear( args.hidden_size, 2 * projection_size, gather_output=False, init_method=init_method, use_cpu_initialization=args.use_cpu_initialization) coeff = None self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) if self.apply_query_key_layer_scaling: coeff = self.layer_number self.norm_factor *= coeff self.scale_mask_softmax = FusedScaleMaskSoftmax( self.fp16, self.bf16, self.attn_mask_type, args.masked_softmax_fusion, attention_mask_func, self.attention_softmax_in_fp32, coeff, ) # Dropout. Note that for a single iteration, this layer will generate # different outputs on different number of parallel partitions but # on average it should not be partition dependent. self.attention_dropout = torch.nn.Dropout(args.attention_dropout) # Output. self.dense = tensor_parallel.RowParallelLinear( projection_size, args.hidden_size, input_is_parallel=True, init_method=output_layer_init_method, skip_bias_add=True, use_cpu_initialization=args.use_cpu_initialization ) # Inference key-value memory self.inference_key_memory = None self.inference_value_memory = None self.inference_current_sequence_len = 0 def _allocate_memory(self, inference_max_sequence_len, batch_size): return torch.empty( inference_max_sequence_len, batch_size, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, dtype=self.params_dtype, device=torch.cuda.current_device(), ) def forward( self, hidden_states, attention_mask, encoder_output=None, set_inference_key_value_memory=False, inference_max_sequence_len=None, ): # hidden_states: [sq, b, h] # ================================================= # Pre-allocate memory for key-values for inference. 
# ================================================= if set_inference_key_value_memory: assert inference_max_sequence_len and inference_max_sequence_len > 0 self.inference_key_memory = self._allocate_memory(inference_max_sequence_len, hidden_states.size(1)) self.inference_value_memory = self._allocate_memory(inference_max_sequence_len, hidden_states.size(1)) self.inference_current_sequence_len = 0 # Some consistency check. if inference_max_sequence_len: assert self.inference_current_sequence_len < self.inference_key_memory.size(0) assert inference_max_sequence_len == self.inference_key_memory.size(0) # This is added for safety. In case inference_max_sequence_len # is not provided, make sure there is no potential memory left # from previous inference. if not inference_max_sequence_len: self.inference_key_memory = None self.inference_value_memory = None # ===================== # Query, Key, and Value # ===================== if self.attention_type == AttnType.self_attn: # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer, _ = self.query_key_value(hidden_states) # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] new_tensor_shape = mixed_x_layer.size()[:-1] + ( self.num_attention_heads_per_partition, 3 * self.hidden_size_per_attention_head, ) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] (query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3) else: # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] mixed_kv_layer, _ = self.key_value(encoder_output) # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] new_tensor_shape = mixed_kv_layer.size()[:-1] + ( self.num_attention_heads_per_partition, 2 * self.hidden_size_per_attention_head, ) mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2) # Attention head [sq, b, h] --> [sq, b, hp] query_layer, _ = self.query(hidden_states) # [sq, b, hp] --> [sq, b, np, hn] new_tensor_shape = query_layer.size()[:-1] + ( self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, ) query_layer = query_layer.view(*new_tensor_shape) # =================================================== # Adjust key, value, and attention mask for inference # =================================================== if inference_max_sequence_len: # Adjust the range variables. start = self.inference_current_sequence_len self.inference_current_sequence_len += key_layer.size(0) end = self.inference_current_sequence_len # Copy key and values. self.inference_key_memory[start:end, ...] = key_layer self.inference_value_memory[start:end, ...] = value_layer key_layer = self.inference_key_memory[:end, ...] value_layer = self.inference_value_memory[:end, ...] # Adjust attention mask attention_mask = attention_mask[..., start:end, :end] # =================================== # Raw attention scores. 
[b, np, s, s] # =================================== # [b, np, sq, sk] output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) # [sq, b, np, hn] -> [sq, b * np, hn] query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) # [sk, b, np, hn] -> [sk, b * np, hn] key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) # preallocting result tensor: [b * np, sq, sk] matmul_result = torch.empty( output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, device=torch.cuda.current_device(), ) # Raw attention scores. [b * np, sq, sk] matmul_result = torch.baddbmm( matmul_result, query_layer.transpose(0, 1), # [b * np, sq, hn] key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] beta=0.0, alpha=(1.0 / self.norm_factor), ) # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) # =========================== # Attention probs and dropout # =========================== # attention scores and attention mask [b, np, sq, sk] attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. with tensor_parallel.get_cuda_rng_tracker().fork(): attention_probs = self.attention_dropout(attention_probs) # ========================= # Context layer. [sq, b, hp] # ========================= # value_layer -> context layer. # [sk, b, np, hn] --> [b, np, sq, hn] # context layer shape: [b, np, sq, hn] output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) # change view [sk, b * np, hn] value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) # change view [b * np, sq, sk] attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) # matmul: [b * np, sq, hn] context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) # change view [b, np, sq, hn] context_layer = context_layer.view(*output_size) # [b, np, sq, hn] --> [sq, b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) context_layer = context_layer.view(*new_context_layer_shape) # ================= # Output. [sq, b, h] # ================= output, bias = self.dense(context_layer) return output, bias @torch.jit.script def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: out = torch.nn.functional.dropout(x + bias, p=prob, training=training) out = residual + out return out def get_bias_dropout_add(training): def _bias_dropout_add(x, bias, residual, prob): return bias_dropout_add(x, bias, residual, prob, training) return _bias_dropout_add @torch.jit.script def bias_dropout_add_fused_train( x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float ) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, True) @torch.jit.script def bias_dropout_add_fused_inference( x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float ) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, False) class ParallelTransformerLayer(MegatronModule): """A single transformer layer. Transformer layer takes input with size [b, s, h] and returns an output of the same size. 
""" def __init__( self, init_method, output_layer_init_method, layer_number, layer_type=LayerType.encoder, self_attn_mask_type=AttnMaskType.padding, ): args = get_args() super().__init__() self.layer_number = layer_number self.layer_type = layer_type self.apply_residual_connection_post_layernorm = args.apply_residual_connection_post_layernorm self.bf16 = args.bf16 self.fp32_residual_connection = args.fp32_residual_connection # Layernorm on the input data. self.input_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon) # Self attention. self.self_attention = ParallelAttention( init_method, output_layer_init_method, layer_number, attention_type=AttnType.self_attn, attn_mask_type=self_attn_mask_type, ) self.hidden_dropout = args.hidden_dropout self.bias_dropout_fusion = args.bias_dropout_fusion # Layernorm on the attention output self.post_attention_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon) if self.layer_type == LayerType.decoder: self.inter_attention = ParallelAttention( init_method, output_layer_init_method, layer_number, attention_type=AttnType.cross_attn ) # Layernorm on the attention output. self.post_inter_attention_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon) # MLP self.mlp = ParallelMLP(init_method, output_layer_init_method) def forward( self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, set_inference_key_value_memory=False, inference_max_sequence_len=None, ): # hidden_states: [b, s, h] # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) # Self attention. attention_output, attention_bias = self.self_attention( layernorm_output, attention_mask, set_inference_key_value_memory=set_inference_key_value_memory, inference_max_sequence_len=inference_max_sequence_len, ) # Residual connection. if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = hidden_states # jit scripting for a nn.module (with dropout) is not # trigerring the fusion kernel. For now, we use two # different nn.functional routines to account for varying # dropout semantics during training and inference phases. if self.bias_dropout_fusion: if self.training: bias_dropout_add_func = bias_dropout_add_fused_train else: bias_dropout_add_func = bias_dropout_add_fused_inference else: bias_dropout_add_func = get_bias_dropout_add(self.training) # re-enable torch grad to enable fused optimization. with torch.enable_grad(): layernorm_input = bias_dropout_add_func( attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout ) # Layer norm post the self attention. layernorm_output = self.post_attention_layernorm(layernorm_input) if self.layer_type == LayerType.decoder: attention_output, attention_bias = self.inter_attention( layernorm_output, enc_dec_attn_mask, encoder_output=encoder_output ) # residual connection if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = layernorm_input # re-enable torch grad to enable fused optimization. with torch.enable_grad(): layernorm_input = bias_dropout_add_func( attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout ) # Layer norm post the decoder attention layernorm_output = self.post_inter_attention_layernorm(layernorm_input) # MLP. mlp_output, mlp_bias = self.mlp(layernorm_output) # Second residual connection. 
if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = layernorm_input # re-enable torch grad to enable fused optimization. with torch.enable_grad(): output = bias_dropout_add_func(mlp_output, mlp_bias.expand_as(residual), residual, self.hidden_dropout) return output class ParallelTransformer(MegatronModule): """Transformer class.""" def __init__( self, init_method, output_layer_init_method, layer_type=LayerType.encoder, self_attn_mask_type=AttnMaskType.padding, pre_process=True, post_process=True, ): super().__init__() args = get_args() self.bf16 = args.bf16 self.fp32_residual_connection = args.fp32_residual_connection self.pre_process = pre_process self.post_process = post_process self.input_tensor = None # Store activation checkpointing flag. self.activations_checkpoint_method = args.activations_checkpoint_method self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers self.distribute_checkpointed_activations = args.distribute_checkpointed_activations num_layers = args.num_layers # Number of layers. assert ( num_layers % parallel_state.get_pipeline_model_parallel_world_size() == 0 ), "num_layers must be divisible by pipeline_model_parallel_size" self.num_layers = num_layers // parallel_state.get_pipeline_model_parallel_world_size() # Transformer layers. def build_layer(layer_number): return ParallelTransformerLayer( init_method, output_layer_init_method, layer_number, layer_type=layer_type, self_attn_mask_type=self_attn_mask_type, ) if args.virtual_pipeline_model_parallel_size is not None: assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, ( "num_layers_per_stage must be divisible by " "virtual_pipeline_model_parallel_size" ) # Number of layers in each model chunk is the number of layers in the stage, # divided by the number of model chunks in a stage. self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of # layers to stages like (each list is a model chunk): # Stage 0: [0] [2] [4] [6] # Stage 1: [1] [3] [5] [7] # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of # layers to stages like (each list is a model chunk): # Stage 0: [0, 1] [4, 5] # Stage 1: [2, 3] [6, 7] offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * ( args.num_layers // args.virtual_pipeline_model_parallel_size ) + (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers) else: # Each stage gets a contiguous set of layers. offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)]) if self.post_process: # Final layer norm before output. 
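            # Worked example for the layer-to-stage offset computation above (an
            # illustrative assumption, not a config set in this file): with
            # args.num_layers=8, pipeline_model_parallel_size=2 and
            # virtual_pipeline_model_parallel_size=2, each model chunk holds
            # 8 // 2 // 2 = 2 layers and
            #   offset = virtual_rank * (8 // 2) + pipeline_rank * 2
            # so pipeline rank 0 builds layers [1, 2] and [5, 6] while rank 1 builds
            # [3, 4] and [7, 8] (layer_number is 1-based), matching the chunk
            # assignment illustrated in the comment above.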
            self.final_layernorm = LayerNorm(args.hidden_size, eps=args.layernorm_epsilon)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask, encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""

        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_

            return custom_forward

        def distribute_checkpointed_activations_helper(layer_number):
            """Distribute checkpointed activations across the tensor-model-parallel
            ranks if `distribute-checkpointed-activations` is on and either of the
            following conditions is met:
            - it is not the first layer in the pipeline stage. The first layer is
              used by pipeline parallelism, and changing its shape throws an error
              in the backward pass.
            - we are at the first pipeline stage, so the input tensor is not used
              in pipeline parallelism. Note that no pipeline parallelism is a
              special case of this.
            """
            not_first_layer_in_pipeline_stage = layer_number > 0
            is_first_pipeline_stage = parallel_state.get_pipeline_model_parallel_rank() == 0
            return self.distribute_checkpointed_activations and (
                not_first_layer_in_pipeline_stage or is_first_pipeline_stage
            )

        if self.activations_checkpoint_method == "uniform":
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # This trades extra recomputation for a further reduction in activation memory.
            l = 0
            while l < self.num_layers:
                hidden_states = tensor_parallel.checkpoint(
                    custom(l, l + self.activations_checkpoint_num_layers),
                    distribute_checkpointed_activations_helper(l),
                    hidden_states,
                    attention_mask,
                    encoder_output,
                    enc_dec_attn_mask,
                )
                l += self.activations_checkpoint_num_layers
        elif self.activations_checkpoint_method == "block":
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # This makes fuller use of device memory by avoiding redundant recomputation.
            for l in range(self.num_layers):
                if l < self.activations_checkpoint_num_layers:
                    hidden_states = tensor_parallel.checkpoint(
                        custom(l, l + 1),
                        distribute_checkpointed_activations_helper(l),
                        hidden_states,
                        attention_mask,
                        encoder_output,
                        enc_dec_attn_mask,
                    )
                else:
                    hidden_states = custom(l, l + 1)(hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism, the input from the previous stage comes
        from communication, not from the input, so the model's forward_step_func
        won't have it. This function is thus used by internal code to bypass the
        input provided by the forward_step_func.
        """
        self.input_tensor = input_tensor

    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_output=None,
        enc_dec_attn_mask=None,
        set_inference_key_value_memory=False,
        inference_max_sequence_len=None,
    ):
        # Checks.
        if inference_max_sequence_len:
            assert self.activations_checkpoint_method is None, "inference does not work with activation checkpointing"

        if self.pre_process:
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert to float.
if self.fp32_residual_connection: hidden_states = hidden_states.transpose(0, 1).contiguous().float() # Otherwise, leave it as is. else: hidden_states = hidden_states.transpose(0, 1).contiguous() else: # See set_input_tensor() hidden_states = self.input_tensor if encoder_output is not None: encoder_output = encoder_output.transpose(0, 1).contiguous() if self.activations_checkpoint_method is not None: hidden_states = self._checkpointed_forward(hidden_states, attention_mask, encoder_output, enc_dec_attn_mask) else: for index in range(self.num_layers): layer = self._get_layer(index) hidden_states = layer( hidden_states, attention_mask, encoder_output=encoder_output, enc_dec_attn_mask=enc_dec_attn_mask, set_inference_key_value_memory=set_inference_key_value_memory, inference_max_sequence_len=inference_max_sequence_len, ) # Final layer norm. if self.post_process: # Reverting data format change [s b h] --> [b s h]. hidden_states = hidden_states.transpose(0, 1).contiguous() output = self.final_layernorm(hidden_states) else: output = hidden_states return output def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None): """LM logits using word embedding weights.""" # Parallel logits. input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_) # Matrix multiply. if bias is None: logits_parallel = F.linear(input_parallel, word_embeddings_weight) else: logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias) # Gather if needed. if parallel_output: return logits_parallel return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel) def get_language_model( num_tokentypes, add_pooler, encoder_attn_mask_type, init_method=None, scaled_init_method=None, add_encoder=True, add_decoder=False, decoder_attn_mask_type=AttnMaskType.causal, pre_process=True, post_process=True, ): """Build language model and return along with the key to save.""" args = get_args() if init_method is None: init_method = init_method_normal(args.init_method_std) if scaled_init_method is None: scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers) # Language model. language_model = TransformerLanguageModel( init_method, scaled_init_method, encoder_attn_mask_type, num_tokentypes=num_tokentypes, add_encoder=add_encoder, add_decoder=add_decoder, decoder_attn_mask_type=decoder_attn_mask_type, add_pooler=add_pooler, pre_process=pre_process, post_process=post_process, ) # key used for checkpoints. language_model_key = "language_model" return language_model, language_model_key class Pooler(MegatronModule): """Pooler layer. Pool hidden states of a specific token (for example start of the sequence) and add a linear transformation followed by a tanh. Arguments: hidden_size: hidden size init_method: weight initialization method for the linear layer. bias is set to zero. """ def __init__(self, hidden_size, init_method): super(Pooler, self).__init__() self.dense = get_linear_layer(hidden_size, hidden_size, init_method) def forward(self, hidden_states, sequence_index=0): # hidden_states: [b, s, h] # sequence_index: index of the token to pool. pooled = hidden_states[:, sequence_index, :] pooled = self.dense(pooled) pooled = torch.tanh(pooled) return pooled class Embedding(MegatronModule): """Language model embeddings. Arguments: hidden_size: hidden size vocab_size: vocabulary size max_sequence_length: maximum size of sequence. 
This is used for positional embedding embedding_dropout_prob: dropout probability for embeddings init_method: weight initialization method num_tokentypes: size of the token-type embeddings. 0 value will ignore this embedding """ def __init__( self, hidden_size, vocab_size, max_sequence_length, embedding_dropout_prob, init_method, num_tokentypes=0 ): super().__init__() self.hidden_size = hidden_size self.init_method = init_method self.num_tokentypes = num_tokentypes args = get_args() # Word embeddings (parallel). self.word_embeddings = tensor_parallel.VocabParallelEmbedding( vocab_size, self.hidden_size, init_method=self.init_method, use_cpu_initialization=args.use_cpu_initialization ) self._word_embeddings_key = "word_embeddings" # Position embedding (serial). self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size) self._position_embeddings_key = "position_embeddings" # Initialize the position embeddings. self.init_method(self.position_embeddings.weight) # Token type embedding. # Add this as an optional field that can be added through # method call so we can load a pretrain model without # token types and add them as needed. self._tokentype_embeddings_key = "tokentype_embeddings" if self.num_tokentypes > 0: self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size) # Initialize the token-type embeddings. self.init_method(self.tokentype_embeddings.weight) else: self.tokentype_embeddings = None # Embeddings dropout self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) if torch.distributed.get_rank() == 0: print("FINISH WORD EMBEDDING", self.word_embeddings) def zero_parameters(self): """Zero out all parameters in embedding.""" self.word_embeddings.weight.data.fill_(0) self.word_embeddings.weight.shared = True self.position_embeddings.weight.data.fill_(0) self.position_embeddings.weight.shared = True if self.num_tokentypes > 0: self.tokentype_embeddings.weight.data.fill_(0) self.tokentype_embeddings.weight.shared = True def add_tokentype_embeddings(self, num_tokentypes): """Add token-type embedding. This function is provided so we can add token-type embeddings in case the pretrained model does not have it. This allows us to load the model normally and then add this embedding. """ if self.tokentype_embeddings is not None: raise Exception("tokentype embeddings is already initialized") if torch.distributed.get_rank() == 0: print("adding embedding for {} tokentypes".format(num_tokentypes), flush=True) self.num_tokentypes = num_tokentypes self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size) # Initialize the token-type embeddings. self.init_method(self.tokentype_embeddings.weight) def forward(self, input_ids, position_ids, tokentype_ids=None): # Embeddings. words_embeddings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) embeddings = words_embeddings + position_embeddings if tokentype_ids is not None: assert self.tokentype_embeddings is not None embeddings = embeddings + self.tokentype_embeddings(tokentype_ids) else: assert self.tokentype_embeddings is None # Dropout. 
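        # At this point `embeddings` has shape [b, s, hidden_size]: the sum of the
        # word, position and (optional) tokentype embeddings for input_ids and
        # position_ids of shape [b, s]. The dropout below is applied elementwise
        # to this summed tensor.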
embeddings = self.embedding_dropout(embeddings) return embeddings def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False): """For easy load.""" state_dict_ = {} state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars) state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(destination, prefix, keep_vars) if self.num_tokentypes > 0: state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict( destination, prefix, keep_vars ) return state_dict_ def load_state_dict(self, state_dict, strict=True): """Customized load.""" # Word embedding. if self._word_embeddings_key in state_dict: state_dict_ = state_dict[self._word_embeddings_key] else: # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): if "word_embeddings" in key: state_dict_[key.split("word_embeddings.")[1]] = state_dict[key] self.word_embeddings.load_state_dict(state_dict_, strict=strict) # Position embedding. if self._position_embeddings_key in state_dict: state_dict_ = state_dict[self._position_embeddings_key] else: # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): if "position_embeddings" in key: state_dict_[key.split("position_embeddings.")[1]] = state_dict[key] self.position_embeddings.load_state_dict(state_dict_, strict=strict) # Tokentype embedding. if self.num_tokentypes > 0: state_dict_ = {} if self._tokentype_embeddings_key in state_dict: state_dict_ = state_dict[self._tokentype_embeddings_key] else: # for backward compatibility. for key in state_dict.keys(): if "tokentype_embeddings" in key: state_dict_[key.split("tokentype_embeddings.")[1]] = state_dict[key] if len(state_dict_.keys()) > 0: self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict) else: print( "***WARNING*** expected tokentype embeddings in the " "checkpoint but could not find it", flush=True ) class TransformerLanguageModel(MegatronModule): """Transformer language model. Arguments: transformer_hparams: transformer hyperparameters vocab_size: vocabulary size max_sequence_length: maximum size of sequence. This is used for positional embedding embedding_dropout_prob: dropout probability for embeddings num_tokentypes: size of the token-type embeddings. 0 value will ignore this embedding """ def __init__( self, init_method, output_layer_init_method, encoder_attn_mask_type, num_tokentypes=0, add_encoder=True, add_decoder=False, decoder_attn_mask_type=AttnMaskType.causal, add_pooler=False, pre_process=True, post_process=True, ): super().__init__() args = get_args() self.pre_process = pre_process self.post_process = post_process self.hidden_size = args.hidden_size self.num_tokentypes = num_tokentypes self.init_method = init_method self.add_encoder = add_encoder self.encoder_attn_mask_type = encoder_attn_mask_type self.add_decoder = add_decoder self.decoder_attn_mask_type = decoder_attn_mask_type self.add_pooler = add_pooler self.encoder_hidden_state = None # Embeddings. if self.pre_process: self.embedding = Embedding( self.hidden_size, args.padded_vocab_size, args.max_position_embeddings, args.hidden_dropout, self.init_method, self.num_tokentypes, ) self._embedding_key = "embedding" # Transformer. # Encoder (usually set to True, False if part of an encoder-decoder # architecture and in encoder-only stage). 
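        # For example (assumed typical configurations, not values fixed by this file):
        # a GPT-style model uses add_encoder=True with a causal encoder_attn_mask_type
        # and add_decoder=False, while a T5-style encoder-decoder model additionally
        # sets add_decoder=True, which here is only supported with
        # pipeline_model_parallel_size == 1 (see the assertion below).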
if self.add_encoder: self.encoder = ParallelTransformer( self.init_method, output_layer_init_method, self_attn_mask_type=self.encoder_attn_mask_type, pre_process=self.pre_process, post_process=self.post_process, ) self._encoder_key = "encoder" else: self.encoder = None # Decoder (usually set to False, True if part of an encoder-decoder # architecture and in decoder-only stage). if self.add_decoder: # Temporary assertion until we verify correctness of pipeline parallelism # implementation of T5. assert ( args.pipeline_model_parallel_size == 1 ), "pipeline parallelism is not supported in the presence of decoder" self.decoder = ParallelTransformer( self.init_method, output_layer_init_method, layer_type=LayerType.decoder, self_attn_mask_type=self.decoder_attn_mask_type, pre_process=self.pre_process, post_process=self.post_process, ) self._decoder_key = "decoder" else: self.decoder = None if self.post_process: # Pooler. if self.add_pooler: self.pooler = Pooler(self.hidden_size, self.init_method) self._pooler_key = "pooler" def set_input_tensor(self, input_tensor): """ See megatron.model.transformer.set_input_tensor()""" # This is usually handled in schedules.py but some inference code still # gives us non-lists or None if not isinstance(input_tensor, list): input_tensor = [input_tensor] if self.add_encoder and self.add_decoder: assert ( len(input_tensor) == 1 ), "input_tensor should only be length 1 for stage with both encoder and decoder" self.encoder.set_input_tensor(input_tensor[0]) elif self.add_encoder: assert len(input_tensor) == 1, "input_tensor should only be length 1 for stage with only encoder" self.encoder.set_input_tensor(input_tensor[0]) elif self.add_decoder: if len(input_tensor) == 2: self.decoder.set_input_tensor(input_tensor[0]) self.encoder_hidden_state = input_tensor[1] elif len(input_tensor) == 1: self.decoder.set_input_tensor(None) self.encoder_hidden_state = input_tensor[0] else: raise Exception("input_tensor must have either length 1 or 2") else: raise Exception("Stage must have at least either encoder or decoder") def forward( self, enc_input_ids, enc_position_ids, enc_attn_mask, dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None, enc_dec_attn_mask=None, tokentype_ids=None, set_inference_key_value_memory=False, inference_max_sequence_len=None, pooling_sequence_index=0, enc_hidden_states=None, output_enc_hidden=False, ): # Encoder embedding. if self.pre_process: encoder_input = self.embedding(enc_input_ids, enc_position_ids, tokentype_ids=tokentype_ids) else: encoder_input = None # Run encoder. if enc_hidden_states is None: if self.encoder is not None: encoder_output = self.encoder( encoder_input, enc_attn_mask, set_inference_key_value_memory=set_inference_key_value_memory, inference_max_sequence_len=inference_max_sequence_len, ) else: encoder_output = self.encoder_hidden_state else: encoder_output = enc_hidden_states.to(encoder_input.dtype) if self.post_process: if self.add_pooler: pooled_output = self.pooler(encoder_output, pooling_sequence_index) # output_enc_hidden refers to when we just need the encoder's # output. For example, it is helpful to compute # similarity between two sequences by average pooling if not self.add_decoder or output_enc_hidden: if self.add_pooler and self.post_process: return encoder_output, pooled_output else: return encoder_output # Decoder embedding. if self.pre_process: decoder_input = self.embedding(dec_input_ids, dec_position_ids) else: decoder_input = None # Run decoder. 
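        # The decoder attends to its own (causally masked) inputs via self-attention
        # and to encoder_output via the cross-attention that ParallelTransformerLayer
        # adds for LayerType.decoder layers, masked by enc_dec_attn_mask.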
decoder_output = self.decoder( decoder_input, dec_attn_mask, encoder_output=encoder_output, enc_dec_attn_mask=enc_dec_attn_mask, set_inference_key_value_memory=set_inference_key_value_memory, inference_max_sequence_len=inference_max_sequence_len, ) if self.add_pooler and self.post_process: return decoder_output, encoder_output, pooled_output else: return decoder_output, encoder_output def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False): """For easy load.""" state_dict_ = {} if self.pre_process: state_dict_[self._embedding_key] = self.embedding.state_dict_for_save_checkpoint( destination, prefix, keep_vars ) if self.add_encoder: state_dict_[self._encoder_key] = self.encoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars) if self.post_process: if self.add_pooler: state_dict_[self._pooler_key] = self.pooler.state_dict_for_save_checkpoint( destination, prefix, keep_vars ) if self.add_decoder: state_dict_[self._decoder_key] = self.decoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars) return state_dict_ def load_state_dict(self, state_dict, strict=True): """Customized load.""" # Embedding. if self.pre_process: if self._embedding_key in state_dict: state_dict_ = state_dict[self._embedding_key] else: # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): if "_embeddings" in key: state_dict_[key] = state_dict[key] self.embedding.load_state_dict(state_dict_, strict=strict) # Encoder. if self.add_encoder: if self._encoder_key in state_dict: state_dict_ = state_dict[self._encoder_key] # For backward compatibility. elif "transformer" in state_dict: state_dict_ = state_dict["transformer"] else: # For backward compatibility. state_dict_ = {} for key in state_dict.keys(): if "transformer." in key: state_dict_[key.split("transformer.")[1]] = state_dict[key] # For backward compatibility. state_dict_self_attention = {} for key in state_dict_.keys(): if ".attention." in key: state_dict_self_attention[key.replace(".attention.", ".self_attention.")] = state_dict_[key] else: state_dict_self_attention[key] = state_dict_[key] state_dict_ = state_dict_self_attention self.encoder.load_state_dict(state_dict_, strict=strict) # Pooler. if self.post_process: if self.add_pooler: assert "pooler" in state_dict, "could not find data for pooler in the checkpoint" self.pooler.load_state_dict(state_dict[self._pooler_key], strict=strict) # Decoder. if self.add_decoder: assert "decoder" in state_dict, "could not find data for pooler in the checkpoint" self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict) def post_language_model_processing(lm_output, labels, logit_weights, parallel_output, fp16_lm_cross_entropy): # Output. 
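    # parallel_lm_logits() ties the output projection to the (vocab-parallel) word
    # embedding weights: it copies lm_output into the tensor-parallel region, applies
    # F.linear against the embedding matrix, and either returns the vocab-parallel
    # logits (parallel_output=True) or gathers them across tensor-parallel ranks.
    # The matching loss below therefore uses vocab_parallel_cross_entropy rather than
    # a plain cross entropy.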
output = parallel_lm_logits(lm_output, logit_weights, parallel_output) if labels is None: return output else: if fp16_lm_cross_entropy: assert output.dtype == torch.half loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels) else: loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels) return loss def module_size(m: torch.nn.Module, only_trainable: bool = False): """ returns the total number of parameters used by `m` (only counting shared parameters once); if `only_trainable` is True, then only includes parameters with `requires_grad = True` """ parameters = list(m.parameters()) if only_trainable: parameters = [p for p in parameters if p.requires_grad] unique = {p.data_ptr(): p for p in parameters}.values() return sum(p.numel() for p in unique) class GPTModel(MegatronModule): """GPT-2 Language model.""" def __init__(self, num_tokentypes=0, parallel_output=True, pre_process=True, post_process=True, cpu_offload=False): super(GPTModel, self).__init__() args = get_args() self.parallel_output = parallel_output self.pre_process = pre_process self.post_process = post_process self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy self.cpu_offload = cpu_offload self.language_model, self._language_model_key = get_language_model( num_tokentypes=num_tokentypes, add_pooler=False, encoder_attn_mask_type=AttnMaskType.causal, init_method=init_method_normal(args.init_method_std), scaled_init_method=scaled_init_method_normal(args.init_method_std, args.num_layers), pre_process=self.pre_process, post_process=self.post_process, ) self.initialize_word_embeddings(init_method_normal) def set_input_tensor(self, input_tensor): """See megatron.model.transformer.set_input_tensor()""" self.language_model.set_input_tensor(input_tensor) def forward( self, input_ids, position_ids, attention_mask, labels=None, tokentype_ids=None, set_inference_key_value_memory=False, inference_max_sequence_len=None, ): with torch.autograd.graph.save_on_cpu() if self.cpu_offload else contextlib.nullcontext(): lm_output = self.language_model( input_ids, position_ids, attention_mask, set_inference_key_value_memory=set_inference_key_value_memory, inference_max_sequence_len=inference_max_sequence_len, ) if self.post_process: return post_language_model_processing( lm_output, labels, self.word_embeddings_weight(), self.parallel_output, self.fp16_lm_cross_entropy ) else: return lm_output def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False): state_dict_ = {} state_dict_[self._language_model_key] = self.language_model.state_dict_for_save_checkpoint( destination, prefix, keep_vars ) # Save word_embeddings. if self.post_process and not self.pre_process: state_dict_[self._word_embeddings_for_head_key] = self.word_embeddings.state_dict( destination, prefix, keep_vars ) return state_dict_ def load_state_dict(self, state_dict, strict=True): """Customized load.""" # Load word_embeddings. 
        if self.post_process and not self.pre_process:
            self.word_embeddings.load_state_dict(state_dict[self._word_embeddings_for_head_key], strict=strict)
        if self._language_model_key in state_dict:
            state_dict = state_dict[self._language_model_key]
        self.language_model.load_state_dict(state_dict, strict=strict)


def gpt_model_provider(pre_process=True, post_process=False, cpu_offload=False):
    model = GPTModel(num_tokentypes=0, parallel_output=True, pre_process=pre_process,
                     post_process=post_process, cpu_offload=cpu_offload)
    if torch.distributed.get_rank() == 0:
        init_dict = {"pre_process": pre_process, "post_process": post_process, "cpu_offload": cpu_offload}
        print("Initialized GPT-2 w/:", json.dumps(init_dict))
        # Approximate total parameter count: per-rank unique parameters scaled by the
        # tensor- and pipeline-model-parallel world sizes.
        n_params = (module_size(model)
                    * parallel_state.get_tensor_model_parallel_world_size()
                    * parallel_state.get_pipeline_model_parallel_world_size())
        print("Number of Parameters:", n_params)
    return model
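

# A minimal usage sketch (not part of the original file, and only an assumption about
# how this provider is typically driven): it presumes the Megatron-style global args
# have been set up via apex.transformer.testing.global_vars and that
# apex.transformer.parallel_state / torch.distributed have been initialized, e.g. with
# tensor and pipeline model parallel size 1 so this rank is both first and last stage.
def _example_build_gpt_model():
    """Build a single-stage GPTModel and report its per-rank parameter count."""
    model = gpt_model_provider(pre_process=True, post_process=True)
    model = model.cuda()
    num_params = module_size(model, only_trainable=True)
    print("Trainable parameters on this rank:", num_params)
    return model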