# Second stage of a two-part ASP (apex.contrib.sparsity) checkpointing script.
# When run as a script it loads "part1.chkp", rebuilds the model and optimizer,
# restores their saved state and trains for a further
# ``args.num_sparse_steps_2`` steps.

from collections import OrderedDict

import torch
from apex.optimizers import FusedAdam
from apex.contrib.sparsity import ASP


def build_model(args):
    """Build a ``torch.nn.Sequential`` of alternating Linear/LayerNorm layers.

    Args:
        args: object exposing ``num_layers``, ``input_features``,
            ``hidden_features``, ``output_features`` and ``batch_size``.

    Returns:
        A ``torch.nn.Sequential`` whose sub-modules are named
        ``linear_layer_<k>`` / ``layer_norm_<k>`` for ``k = 1..num_layers``.

    The first layer maps input -> hidden, the last maps hidden -> output and
    every layer in between maps hidden -> hidden.  When ``num_layers == 1``
    the "first layer" rule wins, so the single layer maps input -> hidden.
    """
    final_idx = args.num_layers - 1
    layers = OrderedDict()

    for idx in range(args.num_layers):
        # Pick the Linear dimensions for this position in the stack.
        if idx == 0:
            in_features = args.input_features
            out_features = args.hidden_features
        elif idx == final_idx:
            in_features = args.hidden_features
            out_features = args.output_features
        else:
            in_features = args.hidden_features
            out_features = args.hidden_features

        layer_no = idx + 1  # names are 1-based
        layers["linear_layer_%d" % layer_no] = torch.nn.Linear(
            in_features, out_features
        )
        # The normalized shape includes batch_size, exactly as the
        # checkpointed model was built; changing it would alter the
        # LayerNorm parameter shapes stored in the state dict.
        layers["layer_norm_%d" % layer_no] = torch.nn.LayerNorm(
            [args.batch_size, out_features]
        )

    return torch.nn.Sequential(layers)


def train_step(args, model, optimizer, input_batch, target_batch, step):
    """Run one optimisation step and return the incremented step counter.

    Args:
        args: accepted for signature symmetry; not read by this function.
        model: callable applied to ``input_batch``.
        optimizer: object providing ``step()`` and ``zero_grad()``.
        input_batch: tensor fed to ``model``.
        target_batch: tensor the prediction is compared against.
        step: current step count.

    Returns:
        ``step + 1``.
    """
    prediction = model(input_batch)
    residual = prediction - target_batch
    # Sum-of-squared-errors loss over every element.
    loss = residual.pow(2).sum()

    loss.backward()
    optimizer.step()
    # Gradients are cleared after the update, ready for the next step.
    optimizer.zero_grad()

    # print("Step %d :: loss=%e" % (step + 1, loss.item()))
    return step + 1


def train_loop(args, model, optimizer, step, num_steps):
    """Train for ``num_steps`` steps on freshly sampled random batches.

    Each iteration draws an input batch of shape
    ``[args.batch_size, args.input_features]`` and a target batch of shape
    ``[args.batch_size, args.output_features]`` with ``torch.randn``, moves
    both to CUDA and calls :func:`train_step`.

    Returns:
        The step counter after all iterations.
    """
    input_shape = [args.batch_size, args.input_features]
    target_shape = [args.batch_size, args.output_features]

    for _ in range(num_steps):
        # Input is sampled before target so the RNG stream is consumed in
        # the same order on every run.
        inputs = torch.randn(input_shape).cuda()
        targets = torch.randn(target_shape).cuda()
        step = train_step(args, model, optimizer, inputs, targets, step)

    return step


def main(step, args, model_state_dict, optimizer_state_dict):
    """Restore a checkpointed model/optimizer pair and continue training.

    Args:
        step: step counter to resume from.
        args: configuration object (see the ``Args`` class in the script
            entry point for the attributes that are read).
        model_state_dict: state dict passed to ``model.load_state_dict``.
        optimizer_state_dict: state dict passed to
            ``optimizer.load_state_dict``.

    Returns:
        None.  The final step count is computed but not returned.
    """
    #
    # PART2
    #

    model = build_model(args).cuda()
    # Weight of the first child module, printed below.
    first_weight = next(model.children()).weight
    optimizer = FusedAdam(model.parameters())

    # ASP is initialised on the fresh model/optimizer *before* the saved
    # state is loaded into them.
    ASP.init_model_for_pruning(
        model,
        args.pattern,
        verbosity=args.verbosity,
        whitelist=args.whitelist,
        allow_recompute_mask=args.allow_recompute_mask,
    )
    ASP.init_optimizer_for_pruning(optimizer)

    torch.manual_seed(args.seed2)
    model.load_state_dict(model_state_dict)
    optimizer.load_state_dict(optimizer_state_dict)

    sparsity_state = "enabled" if ASP.is_sparsity_enabled() else "disabled"
    print("Model sparsity is %s" % sparsity_state)

    # train for a few steps with sparse weights
    print("SPARSE :: ", first_weight)
    step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)


if __name__ == '__main__':
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.  Presumably this file is produced by the
    # companion "part 1" script -- confirm.  Newer PyTorch releases may also
    # need an explicit ``weights_only`` choice here; verify against the
    # installed version.
    checkpoint = torch.load("part1.chkp")

    class Args:
        # Values restored from the checkpoint.
        verbosity = checkpoint['verbosity']
        seed2 = checkpoint['seed2']
        pattern = checkpoint['pattern']
        whitelist = checkpoint['whitelist']
        allow_recompute_mask = checkpoint['allow_recompute_mask']

        # Model geometry read by build_model / train_loop.
        batch_size = 32
        input_features = 8
        output_features = 8
        hidden_features = 32
        num_layers = 4

        # Number of training steps run by main().
        num_sparse_steps_2 = 1000

        # The remaining settings are not read by the code in this file.
        seed = 4873
        num_dense_steps = 2000
        num_sparse_steps = 3000
        checkpoint_path = "part1.chkp"

    args = Args()

    main(
        checkpoint['step'],
        args,
        checkpoint['model_state_dict'],
        checkpoint['optimizer_state_dict'],
    )