# import torch
# import torchvision
# from models import basicblock as B


def show_kv(net):
    """Print every key of the mapping ``net``, one per line.

    Args:
        net: A dict-like object (e.g. a loaded state dict) whose keys
            should be listed.
    """
    for key in net:
        print(key)


# should run train debug mode first to get an initial model
# crt_net = torch.load('../../experiments/debug_SRResNet_bicx4_in3nf64nb16/models/8_G.pth')
#
# for k, v in crt_net.items():
#     print(k)
# for k, v in crt_net.items():
#     if k in pretrained_net:
#         crt_net[k] = pretrained_net[k]
#         print('replace ... ', k)

# x2 -> x4
# crt_net['model.5.weight'] = pretrained_net['model.2.weight']
# crt_net['model.5.bias'] = pretrained_net['model.2.bias']
# crt_net['model.8.weight'] = pretrained_net['model.5.weight']
# crt_net['model.8.bias'] = pretrained_net['model.5.bias']
# crt_net['model.10.weight'] = pretrained_net['model.7.weight']
# crt_net['model.10.bias'] = pretrained_net['model.7.bias']
# torch.save(crt_net, '../pretrained_tmp.pth')

# x2 -> x3
'''
in_filter = pretrained_net['model.2.weight'] # 256, 64, 3, 3
new_filter = torch.Tensor(576, 64, 3, 3)
new_filter[0:256, :, :, :] = in_filter
new_filter[256:512, :, :, :] = in_filter
new_filter[512:, :, :, :] = in_filter[0:576-512, :, :, :]
crt_net['model.2.weight'] = new_filter

in_bias = pretrained_net['model.2.bias']  # 256, 64, 3, 3
new_bias = torch.Tensor(576)
new_bias[0:256] = in_bias
new_bias[256:512] = in_bias
new_bias[512:] = in_bias[0:576 - 512]
crt_net['model.2.bias'] = new_bias

torch.save(crt_net, '../pretrained_tmp.pth')
'''

# x2 -> x8
'''
crt_net['model.5.weight'] = pretrained_net['model.2.weight']
crt_net['model.5.bias'] = pretrained_net['model.2.bias']
crt_net['model.8.weight'] = pretrained_net['model.2.weight']
crt_net['model.8.bias'] = pretrained_net['model.2.bias']
crt_net['model.11.weight'] = pretrained_net['model.5.weight']
crt_net['model.11.bias'] = pretrained_net['model.5.bias']
crt_net['model.13.weight'] = pretrained_net['model.7.weight']
crt_net['model.13.bias'] = pretrained_net['model.7.bias']
torch.save(crt_net, '../pretrained_tmp.pth')
'''

# x3/4/8 RGB -> Y


def rgb2gray_net(net, only_input=True):
    """Collapse the 3-channel input filter of ``net`` to a single channel.

    When ``only_input`` is true, ``net['0.weight']`` is replaced by the
    weighted sum of its entries 0, 1 and 2 along dim 1 (weights 0.2989,
    0.587 and 0.114), with dim 1 kept as a singleton. The indexing requires
    a 4-D tensor with at least three entries along dim 1.

    Args:
        net: Dict-like parameter container holding a ``'0.weight'`` entry.
            It is modified in place.
        only_input: If false, ``net`` is returned untouched.

    Returns:
        The same ``net`` object that was passed in.
    """
    if only_input:
        in_filter = net['0.weight']
        # Weighted channel sum (RGB -> Y); the arithmetic order matches the
        # original expression so floating-point results are unchanged.
        in_new_filter = (
            in_filter[:, 0, :, :] * 0.2989
            + in_filter[:, 1, :, :] * 0.587
            + in_filter[:, 2, :, :] * 0.114
        )
        # Re-insert the channel axis that the integer indexing removed.
        net['0.weight'] = in_new_filter.unsqueeze(1)

    # out_filter = pretrained_net['model.13.weight']
    # out_new_filter = out_filter[0, :, :, :] * 0.2989 + out_filter[1, :, :, :] * 0.587 + \
    #     out_filter[2, :, :, :] * 0.114
    # out_new_filter.unsqueeze_(0)
    # crt_net['model.13.weight'] = out_new_filter
    # out_bias = pretrained_net['model.13.bias']
    # out_new_bias = out_bias[0] * 0.2989 + out_bias[1] * 0.587 + out_bias[2] * 0.114
    # out_new_bias = torch.Tensor(1).fill_(out_new_bias)
    # crt_net['model.13.bias'] = out_new_bias

    # torch.save(crt_net, '../pretrained_tmp.pth')

    return net


# if __name__ == '__main__':
#     net = torchvision.models.vgg19(pretrained=True)
#     for k, v in net.features.named_parameters():
#         if k == '0.weight':
#             in_new_filter = v[:, 0, :, :] * 0.2989 + \
#                 v[:, 1, :, :] * 0.587 + v[:, 2, :, :] * 0.114
#             in_new_filter.unsqueeze_(1)
#             v = in_new_filter
#             print(v.shape)
#             print(v[0, 0, 0, 0])
#         if k == '0.bias':
#             in_new_bias = v
#             print(v[0])
#     print(net.features[0])
#     net.features[0] = B.conv(1, 64, mode='C')
#     print(net.features[0])
#     net.features[0].weight.data = in_new_filter
#     net.features[0].bias.data = in_new_bias
#     for k, v in net.features.named_parameters():
#         if k == '0.weight':
#             print(v[0, 0, 0, 0])
#         if k == '0.bias':
#             print(v[0])

# # transfer parameters of old model to new one
# model_old = torch.load(model_path)
# state_dict = model.state_dict()
# for ((key, param), (key2, param2)) in zip(
#         model_old.items(), state_dict.items()):
#     state_dict[key2] = param
#     print([key, key2])
#     # print([param.size(), param2.size()])
# torch.save(state_dict, 'model_new.pth')

# # rgb2gray_net(net)