from collections import OrderedDict

import torch
from apex.optimizers import FusedAdam
from apex.contrib.sparsity import ASP

#
# Reference run for checkpointing test (part1 + part2)
#


def build_model(args):
    """Build a ``torch.nn.Sequential`` of alternating Linear / LayerNorm layers.

    Layer ``k`` (1-based) is registered as ``linear_layer_k`` followed by
    ``layer_norm_k``. The first Linear reads ``args.input_features``, the
    last one produces ``args.output_features``, and every other width is
    ``args.hidden_features``. Each LayerNorm is built with the normalized
    shape ``[args.batch_size, <out width of the preceding Linear>]``.

    Args:
        args: object exposing ``num_layers``, ``batch_size``,
            ``input_features``, ``hidden_features`` and ``output_features``.

    Returns:
        The assembled ``torch.nn.Sequential`` (not yet moved to any device).
    """
    od = OrderedDict()
    last = args.num_layers - 1
    for i in range(args.num_layers):
        # Decide input and output widths independently so that a single-layer
        # model (i is both first and last) maps input_features directly to
        # output_features instead of stopping at hidden_features.
        in_features = args.input_features if i == 0 else args.hidden_features
        out_features = args.output_features if i == last else args.hidden_features
        # Creation order (Linear, then LayerNorm) is kept so parameter
        # initialisation draws from the RNG in the same sequence as before.
        od['linear_layer_%d' % (i+1)] = torch.nn.Linear(in_features, out_features)
        od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, out_features])
    return torch.nn.Sequential(od)


def train_step(args, model, optimizer, input_batch, target_batch, step):
    """Run one optimisation step and return the incremented step counter.

    The loss is the sum of squared differences between ``model(input_batch)``
    and ``target_batch``. Gradients are cleared after the optimizer update.

    Args:
        args: unused here; kept for a uniform call signature.
        model: callable module producing predictions for ``input_batch``.
        optimizer: optimizer exposing ``step()`` and ``zero_grad()``.
        input_batch: model input tensor.
        target_batch: tensor the prediction is compared against.
        step: current step count.

    Returns:
        ``step + 1``.
    """
    predicted_target = model(input_batch)
    loss = ((predicted_target-target_batch)**2).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    step = step + 1
    #print("Step %d :: loss=%e" % (step, loss.item()))
    return step


def train_loop(args, model, optimizer, step, num_steps):
    """Run ``num_steps`` training steps on freshly drawn random batches.

    Each iteration draws an input batch of shape
    ``[args.batch_size, args.input_features]`` and a target batch of shape
    ``[args.batch_size, args.output_features]`` with ``torch.randn`` and then
    moves them to the GPU with ``.cuda()``.

    Args:
        args: object exposing ``batch_size``, ``input_features`` and
            ``output_features``.
        model: model to train.
        optimizer: optimizer driving the updates.
        step: step count before the loop starts.
        num_steps: number of iterations to run.

    Returns:
        The step count after the loop (``step + num_steps``).
    """
    for _ in range(num_steps):
        # Input is drawn before target; keep this order so the random stream
        # consumed by the reference run stays the same.
        input_batch = torch.randn([args.batch_size, args.input_features]).cuda()
        target_batch = torch.randn([args.batch_size, args.output_features]).cuda()
        step = train_step(args, model, optimizer, input_batch, target_batch, step)
    return step


def main(args):
    """Execute the uninterrupted reference run (part 1 followed by part 2).

    Part 1 seeds with ``args.seed``, builds the model and a ``FusedAdam``
    optimizer, registers both with ASP, trains ``args.num_dense_steps`` steps,
    calls ``ASP.compute_sparse_masks()`` and trains ``args.num_sparse_steps``
    more steps. Part 2 reseeds with ``args.seed2`` and trains
    ``args.num_sparse_steps_2`` further steps on the same model and optimizer.
    The weight of the first child module is printed before each phase.

    Args:
        args: configuration object; see the ``Args`` class in the
            ``__main__`` section for the attributes that are read.
    """
    #
    # PART1
    #

    torch.manual_seed(args.seed)

    model = build_model(args).cuda()
    # Weight of the first child module, printed before each phase below.
    one_ll = next(model.children()).weight
    optimizer = FusedAdam(model.parameters())
    ASP.init_model_for_pruning(model, args.pattern, whitelist=args.whitelist, allow_recompute_mask=args.allow_recompute_mask)
    ASP.init_optimizer_for_pruning(optimizer)

    step = 0

    # train for a few steps with dense weights
    print("DENSE :: ",one_ll)
    step = train_loop(args, model, optimizer, step, args.num_dense_steps)

    # simulate sparsity by inserting zeros into existing dense weights
    ASP.compute_sparse_masks()

    # train for a few steps with sparse weights
    print("SPARSE :: ",one_ll)
    step = train_loop(args, model, optimizer, step, args.num_sparse_steps)

    #
    # PART 2
    #

    torch.manual_seed(args.seed2)

    # train for a few steps with sparse weights
    print("SPARSE :: ",one_ll)
    step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)


if __name__ == '__main__':
    class Args:
        # RNG seeds applied at the start of part 1 and part 2 respectively.
        seed = 4873
        seed2 = 99875
        # Arguments forwarded to ASP.init_model_for_pruning.
        pattern = "m4n2_2d_best"
        whitelist = [torch.nn.Linear]
        allow_recompute_mask = True
        # Model and batch dimensions.
        batch_size = 32
        input_features = 8
        output_features = 8
        hidden_features = 32
        num_layers = 4
        # Number of training steps in each phase.
        num_dense_steps = 2000
        num_sparse_steps = 3000
        num_sparse_steps_2 = 1000
        # Not read anywhere in this script.
        checkpoint_path = "part1.chkp"
    args = Args()

    main(args)