# -*- coding: utf-8 -*- import numpy as np import torch from utils import utils_image as util import re import glob import os ''' # -------------------------------------------- # Model # -------------------------------------------- # Kai Zhang (github: https://github.com/cszn) # 03/Mar/2019 # -------------------------------------------- ''' def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None): """ # --------------------------------------- # Kai Zhang (github: https://github.com/cszn) # 03/Mar/2019 # --------------------------------------- Args: save_dir: model folder net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD' pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path Return: init_iter: iteration number init_path: model path # --------------------------------------- """ file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type))) if file_list: iter_exist = [] for file_ in file_list: iter_current = re.findall(r"(\d+)_{}.pth".format(net_type), file_) iter_exist.append(int(iter_current[0])) init_iter = max(iter_exist) init_path = os.path.join( save_dir, '{}_{}.pth'.format( init_iter, net_type)) else: init_iter = 0 init_path = pretrained_path return init_iter, init_path def test_mode(model, L, mode=0, refield=32, min_size=256, sf=1, modulo=1): ''' # --------------------------------------- # Kai Zhang (github: https://github.com/cszn) # 03/Mar/2019 # --------------------------------------- Args: model: trained model L: input Low-quality image mode: (0) normal: test(model, L) (1) pad: test_pad(model, L, modulo=16) (2) split: test_split(model, L, refield=32, min_size=256, sf=1, modulo=1) (3) x8: test_x8(model, L, modulo=1) ^_^ (4) split and x8: test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1) refield: effective receptive filed of the network, 32 is enough useful when split, i.e., mode=2, 4 min_size: min_sizeXmin_size image, e.g., 256X256 image useful when split, i.e., mode=2, 4 sf: scale factor for super-resolution, otherwise 1 modulo: 1 if split useful when pad, i.e., mode=1 Returns: E: estimated image # --------------------------------------- ''' if mode == 0: E = test(model, L) elif mode == 1: E = test_pad(model, L, modulo, sf) elif mode == 2: E = test_split(model, L, refield, min_size, sf, modulo) elif mode == 3: E = test_x8(model, L, modulo, sf) elif mode == 4: E = test_split_x8(model, L, refield, min_size, sf, modulo) return E ''' # -------------------------------------------- # normal (0) # -------------------------------------------- ''' def test(model, L): E = model(L) return E ''' # -------------------------------------------- # pad (1) # -------------------------------------------- ''' def test_pad(model, L, modulo=16, sf=1): h, w = L.size()[-2:] paddingBottom = int(np.ceil(h / modulo) * modulo - h) paddingRight = int(np.ceil(w / modulo) * modulo - w) L = torch.nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(L) E = model(L) E = E[..., :h * sf, :w * sf] return E ''' # -------------------------------------------- # split (function) # -------------------------------------------- ''' def test_split_fn(model, L, refield=32, min_size=256, sf=1, modulo=1): """ Args: model: trained model L: input Low-quality image refield: effective receptive filed of the network, 32 is enough min_size: min_sizeXmin_size image, e.g., 256X256 image sf: scale factor for super-resolution, otherwise 1 modulo: 1 if split Returns: E: estimated result """ h, w = L.size()[-2:] if h * w <= min_size**2: L = torch.nn.ReplicationPad2d((0, int(np.ceil(w / modulo) * modulo - w), 0, int(np.ceil(h / modulo) * modulo - h)))(L) E = model(L) E = E[..., :h * sf, :w * sf] else: top = slice(0, (h // 2 // refield + 1) * refield) bottom = slice(h - (h // 2 // refield + 1) * refield, h) left = slice(0, (w // 2 // refield + 1) * refield) right = slice(w - (w // 2 // refield + 1) * refield, w) Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]] if h * w <= 4 * (min_size**2): Es = [model(Ls[i]) for i in range(4)] else: Es = [ test_split_fn( model, Ls[i], refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(4)] b, c = Es[0].size()[:2] E = torch.zeros(b, c, sf * h, sf * w).type_as(L) E[..., :h // 2 * sf, :w // 2 * sf] = Es[0][..., :h // 2 * sf, :w // 2 * sf] E[..., :h // 2 * sf, w // 2 * sf:w * sf] = Es[1][..., :h // 2 * sf, (-w + w // 2) * sf:] E[..., h // 2 * sf:h * sf, :w // 2 * sf] = Es[2][..., (-h + h // 2) * sf:, :w // 2 * sf] E[..., h // 2 * sf:h * sf, w // 2 * sf:w * sf] = Es[3][..., (-h + h // 2) * sf:, (-w + w // 2) * sf:] return E ''' # -------------------------------------------- # split (2) # -------------------------------------------- ''' def test_split(model, L, refield=32, min_size=256, sf=1, modulo=1): E = test_split_fn( model, L, refield=refield, min_size=min_size, sf=sf, modulo=modulo) return E ''' # -------------------------------------------- # x8 (3) # -------------------------------------------- ''' def test_x8(model, L, modulo=1, sf=1): E_list = [ test_pad( model, util.augment_img_tensor4( L, mode=i), modulo=modulo, sf=sf) for i in range(8)] for i in range(len(E_list)): if i == 3 or i == 5: E_list[i] = util.augment_img_tensor4(E_list[i], mode=8 - i) else: E_list[i] = util.augment_img_tensor4(E_list[i], mode=i) output_cat = torch.stack(E_list, dim=0) E = output_cat.mean(dim=0, keepdim=False) return E ''' # -------------------------------------------- # split and x8 (4) # -------------------------------------------- ''' def test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1): E_list = [ test_split_fn( model, util.augment_img_tensor4( L, mode=i), refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(8)] for k, i in enumerate(range(len(E_list))): if i == 3 or i == 5: E_list[k] = util.augment_img_tensor4(E_list[k], mode=8 - i) else: E_list[k] = util.augment_img_tensor4(E_list[k], mode=i) output_cat = torch.stack(E_list, dim=0) E = output_cat.mean(dim=0, keepdim=False) return E ''' # ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^- # _^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^ # ^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^-^_^- ''' ''' # -------------------------------------------- # print # -------------------------------------------- ''' # -------------------------------------------- # print model # -------------------------------------------- def print_model(model): msg = describe_model(model) print(msg) # -------------------------------------------- # print params # -------------------------------------------- def print_params(model): msg = describe_params(model) print(msg) ''' # -------------------------------------------- # information # -------------------------------------------- ''' # -------------------------------------------- # model inforation # -------------------------------------------- def info_model(model): msg = describe_model(model) return msg # -------------------------------------------- # params inforation # -------------------------------------------- def info_params(model): msg = describe_params(model) return msg ''' # -------------------------------------------- # description # -------------------------------------------- ''' # -------------------------------------------- # model name and total number of parameters # -------------------------------------------- def describe_model(model): if isinstance(model, torch.nn.DataParallel): model = model.module msg = '\n' msg += 'models name: {}'.format(model.__class__.__name__) + '\n' msg += 'Params number: {}'.format( sum(map(lambda x: x.numel(), model.parameters()))) + '\n' msg += 'Net structure:\n{}'.format(str(model)) + '\n' return msg # -------------------------------------------- # parameters description # -------------------------------------------- def describe_params(model): if isinstance(model, torch.nn.DataParallel): model = model.module msg = '\n' msg += '{:^6s} | {:^6s} | {:^6s} | {:^6s} | {:^6s} || {:<20s}'.format( 'mean', 'min', 'max', 'std', 'shape', 'param_name') + '\n' for name, param in model.state_dict().items(): if 'num_batches_tracked' not in name: v = param.data.clone().float() msg += ' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} | {} || {:s}'.format( v.mean(), v.min(), v.max(), v.std(), v.shape, name) + '\n' return msg if __name__ == '__main__': class Net(torch.nn.Module): def __init__(self, in_channels=3, out_channels=3): super(Net, self).__init__() self.conv = torch.nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1) def forward(self, x): x = self.conv(x) return x start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) model = Net() model = model.eval() print_model(model) print_params(model) x = torch.randn((2, 3, 401, 401)) torch.cuda.empty_cache() with torch.no_grad(): for mode in range(5): y = test_mode( model, x, mode, refield=32, min_size=256, sf=1, modulo=1) print(y.shape) # run utils/utils_model.py