Experiment: Reproducing SRResNet in PyTorch
Published: 2019-03-07


These are my notes from experimenting with SRResNet in xintao's codebase (BasicSR). See the link for the code.

First, the link to the SRGAN paper (). The authors define SRResNet in the paper as follows:

“We set a new state of the art for image SR with high upscaling factors (4x) as measured by PSNR and structural similarity (SSIM) with our 16 blocks deep ResNet (SRResNet) optimized for MSE.”

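I also introduced SRGAN (SRResNet) in an earlier post (《》), so I will not go over the theory again here. Interested readers should read the original paper directly.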

The code lives under /home/guanwp/BasicSR-master/codes/. Run the following commands to train and test:

python train.py -opt options/train/train_sr.json

python test.py -opt options/test/test_sr.json

Note the screenshot below.

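To train the SRGAN network, the command to use is presumably the one shown after this paragraph. This post, however, only covers SRResNet. A follow-up post will cover SRGAN (I am not yet skilled enough to train a GAN well, so I am setting SRGAN aside for now).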

python train.py -opt options/train/train_esrgan.json

 

Code analysis

First, the settings file (train_sr.json):

{
  "name": "sr_resnet_baesline"//"001_RRDB_PSNR_x4_DIV2K" //  please remove "debug_" during training or tensorboard would not work
  , "use_tb_logger": true
  , "model":"sr"
  , "scale": 4
  , "gpu_ids": [1]

  , "datasets": {
    "train": {
      "name": "DIV2K800"
      , "mode": "LRHR"
      , "dataroot_HR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub"
      , "dataroot_LR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicLRx4"
      , "subset_file": null
      , "use_shuffle": true
      , "n_workers": 8
      , "batch_size": 16// how many samples in each iteration
      , "HR_size": 192 // 128 | 192
      , "use_flip": true
      , "use_rot": true
    }
    , "val": {
      "name": "val_set5"
      , "mode": "LRHR"
      , "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5"
      , "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_sub_bicLRx4"
    }
  }

  , "path": {
    "root": "/home/guanwp/BasicSR-master",
    "pretrain_model_G": null
    , "experiments_root": "/home/guanwp/BasicSR-master/experiments/",
    "models": "/home/guanwp/BasicSR-master/experiments/sr_resnet_baesline/models",
    "log": "/home/guanwp/BasicSR-master/experiments/sr_resnet_baesline",
    "val_images": "/home/guanwp/BasicSR-master/experiments/sr_resnet_baesline/val_images"
  }

  , "network_G": {
    "which_model_G": "sr_resnet"//"fsrcnn"//"sr_resnet" // RRDB_net | sr_resnet
    , "norm_type": null
    , "mode": "CNA"
    , "nf": 64//56//64
    , "nb": 23
    , "in_nc": 3
    , "out_nc": 3
    , "gc": 32
    , "group": 1
  }

  , "train": {
    "lr_G": 1e-3//1e-3//2e-4
    , "lr_scheme": "MultiStepLR"
    , "lr_steps": [200000,400000,600000,800000,1000000,1500000]
    , "lr_gamma": 0.5

    , "pixel_criterion": "l1"//"l1"//'l2'//huber//Cross
    , "pixel_weight": 1.0
    , "val_freq": 5e3

    , "manual_seed": 0
    , "niter": 2e6//2e6//1e6
  }

  , "logger": {
    "print_freq": 200
    , "save_checkpoint_freq": 5e3
  }
}
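The option file is not strict JSON: it carries `//` comments, which Python's `json` module rejects. Below is a minimal sketch of one way to load such a file, assuming every `//` starts a comment (true for the settings shown here, but it would break on a value that contains `//`, such as a URL). The helper name `parse_commented_json` and the shortened sample are illustrative and not taken from the repository.

```python
import json


def parse_commented_json(text):
    """Drop everything after '//' on each line, then parse the rest as JSON.

    Assumption: '//' never appears inside a string value.
    """
    stripped = "\n".join(line.split("//")[0] for line in text.splitlines())
    return json.loads(stripped)


# A shortened excerpt in the same leading-comma style as train_sr.json.
SAMPLE = '''
{
  "name": "sr_resnet_baesline"//"001_RRDB_PSNR_x4_DIV2K"
  , "scale": 4
  , "train": {
    "lr_G": 1e-3//1e-3//2e-4
    , "lr_steps": [200000,400000]
    , "niter": 2e6//2e6//1e6
  }
}
'''

opt = parse_commented_json(SAMPLE)
print(opt["name"], opt["scale"], opt["train"]["lr_G"], opt["train"]["niter"])
```

Values written as `2e6` or `5e3` come back as floats, so a loader along these lines would need `int(...)` wherever an iteration count is expected.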

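The PSNR stayed stuck in the teens and would not rise, so I decided to change the settings (HR_size 192 -> 128, lr_G 1e-3 -> 2e-4) and see the effect: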

{
  "name": "sr_resnet_x4_baesline"//"001_RRDB_PSNR_x4_DIV2K" //  please remove "debug_" during training or tensorboard would not work
  , "use_tb_logger": true
  , "model":"sr"
  , "scale": 4
  , "gpu_ids": [1]

  , "datasets": {
    "train": {
      "name": "DIV2K800"
      , "mode": "LRHR"
      , "dataroot_HR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub"
      , "dataroot_LR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicLRx4"
      , "subset_file": null
      , "use_shuffle": true
      , "n_workers": 8
      , "batch_size": 16// how many samples in each iteration
      , "HR_size": 128 // 128 | 192
      , "use_flip": true
      , "use_rot": true
    }
    , "val": {
      "name": "val_set5"
      , "mode": "LRHR"
      , "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5"
      , "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_sub_bicLRx4"
    }
  }

  , "path": {
    "root": "/home/guanwp/BasicSR-master",
    "pretrain_model_G": null
    , "experiments_root": "/home/guanwp/BasicSR-master/experiments/",
    "models": "/home/guanwp/BasicSR-master/experiments/sr_resnet_baesline/models",
    "log": "/home/guanwp/BasicSR-master/experiments/sr_resnet_baesline",
    "val_images": "/home/guanwp/BasicSR-master/experiments/sr_resnet_baesline/val_images"
  }

  , "network_G": {
    "which_model_G": "sr_resnet"//"fsrcnn"//"sr_resnet" // RRDB_net | sr_resnet
    , "norm_type": null
    , "mode": "CNA"
    , "nf": 64//56//64
    , "nb": 23
    , "in_nc": 3
    , "out_nc": 3
    , "gc": 32
    , "group": 1
  }

  , "train": {
    "lr_G": 2e-4//1e-3//2e-4
    , "lr_scheme": "MultiStepLR"
    , "lr_steps": [200000,400000,600000,800000,1000000,1500000]
    , "lr_gamma": 0.5

    , "pixel_criterion": "l1"//"l1"//'l2'//huber//Cross
    , "pixel_weight": 1.0
    , "val_freq": 5e3

    , "manual_seed": 0
    , "niter": 2e6//2e6//1e6
  }

  , "logger": {
    "print_freq": 200
    , "save_checkpoint_freq": 5e3
  }
}
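With `lr_scheme` set to `MultiStepLR`, the learning rate is multiplied by `lr_gamma` each time training passes one of the `lr_steps`. The following plain-Python sketch tabulates the resulting schedule for both starting rates tried here, assuming the milestones are counted in iterations and the decay applies from the milestone itself onward. `lr_at` is a hypothetical helper, not part of the codebase.

```python
from bisect import bisect_right

LR_STEPS = [200000, 400000, 600000, 800000, 1000000, 1500000]
LR_GAMMA = 0.5


def lr_at(iteration, base_lr, steps=LR_STEPS, gamma=LR_GAMMA):
    """Learning rate in effect at `iteration` under a multi-step decay."""
    passed = bisect_right(steps, iteration)  # number of milestones already reached
    return base_lr * gamma ** passed


for it in (0, 200000, 1000000, 1999999):
    print(it, lr_at(it, 1e-3), lr_at(it, 2e-4))
```

By the end of the 2e6 iterations all six milestones have been passed, so the rate has been halved six times, i.e. reduced to 1/64 of its starting value.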
network.py

# Generator
def define_G(opt):
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_G']
    # which_model selects the generator; its value comes from the .json file.
    # If you add a new model, this part must be modified.
    which_model = opt_net['which_model_G']

    if which_model == 'sr_resnet':  # SRResNet
        netG = arch.SRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], \
            nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], \
            act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle')
    ############################################################
    elif which_model=='fsrcnn':#FSRCNN
        netG=arch.FSRCNN(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], \
            nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], \
            act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle')
    ############################################################
    elif which_model == 'sft_arch':  # SFT-GAN
        netG = sft_arch.SFT_Net()

    elif which_model == 'RRDB_net':  # RRDB, this is ESRGAN
        netG = arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
            act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv')
    else:
        raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))

    if opt['is_train']:
        # weight initialization; change init_type here to use a different method
        init_weights(netG, init_type='kaiming', scale=0.1)
    if gpu_ids:
        assert torch.cuda.is_available()
        netG = nn.DataParallel(netG)
    return netG

The SRResNet architecture is defined in architecture.py:

#####################SRResNet#####################
class SRResNet(nn.Module):  # nn.Module holds the layers and the forward pass (see my CSDN post on nn.Module)
    def __init__(self, in_nc, out_nc, nf, nb, upscale=4, norm_type='batch', act_type='relu', \
            mode='NAC', res_scale=1, upsample_mode='upconv'):  # the .json file sets mode to "CNA"
        # arguments: input channels, output channels, number of filters in the first layer,
        # number of residual blocks, upscale factor, norm type, activation (relu),
        # mode (CNA means Conv -> Norm -> Act)
        super(SRResNet, self).__init__()  # for super(), see my CSDN post
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)  # see block.py
        # 'nb' is the number of residual blocks; it is 23 in the .json
        resnet_blocks = [B.ResNetBlock(nf, nf, nf, norm_type=norm_type, act_type=act_type,\
            mode=mode, res_scale=res_scale) for _ in range(nb)]
        # norm_type is passed through here: BN is used only when norm_type='batch' (it is null in the .json)
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok  # nearest-neighbor upsampling followed by a conv (see block.py)
        elif upsample_mode == 'pixelshuffle':  # network.py passes 'pixelshuffle'
            upsample_block = B.pixelshuffle_block  # sub-pixel convolution, as in ESPCN
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*resnet_blocks, LR_conv)),\
            *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        x = self.model(x)
        return x
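The upsampling logic can be checked with shape arithmetic alone. For x4, `n_upscale = int(math.log(4, 2))` gives two pixel-shuffle stages; each stage expands the channels by a factor of 4 with a conv and then trades them for a feature map twice as large. The sketch below traces (channels, height, width) without PyTorch. The helper names are made up for illustration, the x3 special case is left out, and the 48x48 input assumes an HR patch of 192 at scale 4.

```python
import math


def pixelshuffle_stage(shape, nf, r=2):
    """Shape after conv(nf -> nf*r*r) followed by PixelShuffle(r)."""
    _, h, w = shape
    c = nf * r * r                        # the conv expands the channels
    return (c // (r * r), h * r, w * r)   # the shuffle moves them into space


def srresnet_output_shape(lr_shape, nf=64, out_nc=3, upscale=4):
    """Trace the tensor shape through SRResNet for a power-of-two upscale."""
    n_upscale = int(math.log(upscale, 2))  # same formula as in SRResNet.__init__
    _, h, w = lr_shape
    shape = (nf, h, w)                     # fea_conv and the residual trunk keep H x W
    for _ in range(n_upscale):
        shape = pixelshuffle_stage(shape, nf)
    return (out_nc, shape[1], shape[2])    # HR_conv0 and HR_conv1 keep H x W


print(srresnet_output_shape((3, 48, 48)))
```

This matches the printed model, where each `Conv2d(64, 256, ...)` is followed by `PixelShuffle(upscale_factor=2)`.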

block.py

from collections import OrderedDict
import torch
import torch.nn as nn

####################
# Basic blocks
####################


def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
    # helper selecting activation
    # neg_slope: for leakyrelu and init of prelu
    # n_prelu: for p_relu num_parameters
    act_type = act_type.lower()
    if act_type == 'relu':
        layer = nn.ReLU(inplace)
    elif act_type == 'leakyrelu':
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
    return layer


def norm(norm_type, nc):  # this is the normalization part
    # helper selecting normalization layer
    norm_type = norm_type.lower()
    if norm_type == 'batch':
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm_type == 'instance':
        layer = nn.InstanceNorm2d(nc, affine=False)
    else:
        raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
    return layer


def pad(pad_type, padding):
    # helper selecting padding layer
    # if padding is 'zero', do by conv layers
    pad_type = pad_type.lower()
    if padding == 0:
        return None
    if pad_type == 'reflect':
        layer = nn.ReflectionPad2d(padding)
    elif pad_type == 'replicate':
        layer = nn.ReplicationPad2d(padding)
    else:
        raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
    return layer


def get_valid_padding(kernel_size, dilation):
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    padding = (kernel_size - 1) // 2
    return padding


class ConcatBlock(nn.Module):
    # Concat the output of a submodule to its input
    def __init__(self, submodule):
        super(ConcatBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = torch.cat((x, self.sub(x)), dim=1)
        return output

    def __repr__(self):
        tmpstr = 'Identity .. \n|'
        modstr = self.sub.__repr__().replace('\n', '\n|')
        tmpstr = tmpstr + modstr
        return tmpstr


class ShortcutBlock(nn.Module):
    # Elementwise sum the output of a submodule to its input
    def __init__(self, submodule):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = x + self.sub(x)
        return output

    def __repr__(self):
        tmpstr = 'Identity + \n|'
        modstr = self.sub.__repr__().replace('\n', '\n|')
        tmpstr = tmpstr + modstr
        return tmpstr


def sequential(*args):
    # Flatten Sequential. It unwraps nn.Sequential.
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return args[0]  # No sequential is needed.
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)


def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, \
               pad_type='zero', norm_type=None, act_type='relu', mode='CNA'):
    '''
    Conv layer with padding, normalization, activation
    mode: CNA --> Conv -> Norm -> Act
        NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
    '''
    assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
    padding = get_valid_padding(kernel_size, dilation)
    p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
    padding = padding if pad_type == 'zero' else 0

    c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, \
            dilation=dilation, bias=bias, groups=groups)
    a = act(act_type) if act_type else None
    if 'CNA' in mode:  # the ordering used in ResNet
        # normalization setting: SRResNet in the paper uses BN, but with norm_type null
        # in the .json no norm layer is created here
        n = norm(norm_type, out_nc) if norm_type else None
        return sequential(p, c, n, a)  # padding, conv, normalization, activation
    elif mode == 'NAC':
        if norm_type is None and act_type is not None:
            a = act(act_type, inplace=False)
            # Important!
            # input----ReLU(inplace)----Conv--+----output
            #        |________________________|
            # inplace ReLU will modify the input, therefore wrong output
        n = norm(norm_type, in_nc) if norm_type else None
        return sequential(n, a, p, c)


####################
# Useful blocks
####################


class ResNetBlock(nn.Module):
    '''
    ResNet Block, 3-3 style
    with extra residual scaling used in EDSR
    (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
    '''

    def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, stride=1, dilation=1, groups=1, \
            bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA', res_scale=1):
        super(ResNetBlock, self).__init__()
        conv0 = conv_block(in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type, \
            norm_type, act_type, mode)
        if mode == 'CNA':
            act_type = None
        if mode == 'CNAC':  # Residual path: |-CNAC-|
            act_type = None
            norm_type = None
        conv1 = conv_block(mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type, \
            norm_type, act_type, mode)
        # if in_nc != out_nc:
        #     self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \
        #         None, None)
        #     print('Need a projecter in ResNetBlock.')
        # else:
        #     self.project = lambda x:x
        self.res = sequential(conv0, conv1)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.res(x).mul(self.res_scale)
        return x + res


class ResidualDenseBlock_5C(nn.Module):
    '''
    Residual Dense Block
    style: 5 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    '''

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv2 = conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv3 = conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv4 = conv_block(nc+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        if mode == 'CNA':
            last_act = None
        else:
            last_act = act_type
        self.conv5 = conv_block(nc+4*gc, nc, 3, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=last_act, mode=mode)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5.mul(0.2) + x


class RRDB(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    '''

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
            norm_type, act_type, mode)
        self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
            norm_type, act_type, mode)
        self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
            norm_type, act_type, mode)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out.mul(0.2) + x


####################
# Upsampler
####################


def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
                        pad_type='zero', norm_type=None, act_type='relu'):
    '''
    Pixel shuffle layer
    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
    Neural Network, CVPR16)
    '''
    conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias, \
                        pad_type=pad_type, norm_type=None, act_type=None)
    pixel_shuffle = nn.PixelShuffle(upscale_factor)

    n = norm(norm_type, out_nc) if norm_type else None
    a = act(act_type) if act_type else None
    return sequential(conv, pixel_shuffle, n, a)


def upconv_blcok(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
                pad_type='zero', norm_type=None, act_type='relu', mode='nearest'):
    # Up conv
    # described in https://distill.pub/2016/deconv-checkerboard/
    upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
    conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias, \
                        pad_type=pad_type, norm_type=norm_type, act_type=act_type)
    return sequential(upsample, conv)
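`get_valid_padding` is what keeps every stride-1 `conv_block` size-preserving for odd kernel sizes: it pads by half of the effective (dilated) kernel size. A small check in plain Python, with the function body copied from block.py and a hypothetical `conv_out_size` helper that applies the standard convolution output-size formula:

```python
def get_valid_padding(kernel_size, dilation):
    # Same arithmetic as block.py: effective kernel size, then half of it.
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    return (kernel_size - 1) // 2


def conv_out_size(size, kernel_size, stride=1, dilation=1):
    """Output length of a convolution that uses the padding computed above."""
    padding = get_valid_padding(kernel_size, dilation)
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


for k, d in [(3, 1), (9, 1), (3, 2)]:
    print(k, d, get_valid_padding(k, d), conv_out_size(48, k, dilation=d))
```

A 3x3 kernel gets padding 1 and a 9x9 kernel gets padding 4, which are the same values that appear in the printed `Conv2d` layers.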

The overall network structure is as follows:

-------------- Generator --------------
SRResNet(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Identity + 
    |Sequential(
    |  (0): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (1): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (2): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (3): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (4): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (5): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (6): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (7): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (8): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (9): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (10): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (11): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (12): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (13): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (14): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (15): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (16): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (17): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (18): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (19): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (20): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (21): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (22): ResNetBlock(
    |    (res): Sequential(
    |      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |      (1): ReLU(inplace)
    |      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |    )
    |  )
    |  (23): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    |)
    (2): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): PixelShuffle(upscale_factor=2)
    (4): ReLU(inplace)
    (5): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): PixelShuffle(upscale_factor=2)
    (7): ReLU(inplace)
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)
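The printed structure is enough to count the learnable parameters by hand: every layer is a 3x3 `Conv2d` with bias, and ReLU and PixelShuffle add none. Here is a rough sketch of that count, assuming the layer layout shown in the printout (no normalization layers, two x2 pixel-shuffle stages). The function names are illustrative only.

```python
def conv_params(in_nc, out_nc, k=3, bias=True):
    """Weights plus optional bias of a k x k Conv2d."""
    return in_nc * out_nc * k * k + (out_nc if bias else 0)


def srresnet_params(nb, nf=64, in_nc=3, out_nc=3, n_upscale=2):
    total = conv_params(in_nc, nf)                # (0) first conv
    total += nb * 2 * conv_params(nf, nf)         # nb ResNetBlocks, two convs each
    total += conv_params(nf, nf)                  # conv that closes the residual trunk
    total += n_upscale * conv_params(nf, nf * 4)  # conv before each PixelShuffle(2)
    total += conv_params(nf, nf)                  # HR_conv0
    total += conv_params(nf, out_nc)              # HR_conv1
    return total


print(srresnet_params(23), srresnet_params(16))
```

With the 23 blocks used here this comes to 2,071,491 parameters, against 1,554,499 for the same layout with 16 blocks.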

Results reported in the paper

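SRResNet reaches 32.05 dB, while FSRCNN at x4 reaches 30.55 dB (in my earlier training run, a little over an hour was enough to match the paper's result, so I did not train any further).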
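The figures quoted here are PSNR values in dB. As a reminder of what the metric measures, here is a minimal sketch of PSNR over two images given as flat lists of 8-bit pixel values. The evaluation protocol behind the quoted numbers (color channel, border cropping) is not stated in this post, so the sketch only illustrates the formula.

```python
import math


def psnr(img_a, img_b, max_val=255.0):
    """PSNR in dB between two equally sized images (flat lists of pixel values)."""
    if len(img_a) != len(img_b):
        raise ValueError("images must have the same number of pixels")
    mse = sum((a - b) ** 2 for a, b in zip(img_a, img_b)) / len(img_a)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)


# Every pixel is off by exactly 1, so the MSE is 1.
print(psnr([0, 128, 255, 64], [1, 127, 254, 65]))
```

Because the scale is logarithmic, halving the mean squared error raises PSNR by about 3 dB.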

Training results

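The results are shown in the figures below, including the original implementation (training was not finished, but the results are already better than those in the paper).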

 

 

 

SRResNet as implemented in the paper

(This should be the paper's implementation: the version above uses 23 blocks, whereas the paper states 16 blocks.)

()

Following the code above, I rewrote the key parts as follows:

network.py

# Generator
def define_G(opt):
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_G']
    # which_model selects the generator; its value comes from the .json file.
    # If you add a new model, this part must be modified.
    which_model = opt_net['which_model_G']

    if which_model == 'sr_resnet':  # SRResNet
        netG = arch.SRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], \
            nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], \
            act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle')
    ############################################################
    elif which_model=='fsrcnn':#FSRCNN
        netG=arch.FSRCNN(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], \
            nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], \
            act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle')
    ############################################################
    elif which_model=='srresnet':#SRResNet, the original version
        netG=arch.OSRRESNET(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], \
            nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], \
            act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle')
    ############################################################
    elif which_model == 'sft_arch':  # SFT-GAN
        netG = sft_arch.SFT_Net()

    elif which_model == 'RRDB_net':  # RRDB, this is ESRGAN
        netG = arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
            act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv')
    else:
        raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))

    if opt['is_train']:
        # weight initialization; change init_type here to use a different method
        init_weights(netG, init_type='kaiming', scale=0.1)
    if gpu_ids:
        assert torch.cuda.is_available()
        netG = nn.DataParallel(netG)
    return netG

In architecture.py:

############################################################
#SRResNet, the original version
#define the residual block
class O_Residual_Block(nn.Module):
    def __init__(self):
        super(O_Residual_Block,self).__init__()
        self.conv1=nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1=nn.BatchNorm2d(64, affine=True)
        self.prelu=nn.PReLU()
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2=nn.BatchNorm2d(64, affine=True)

    def forward(self, x):
        identity_data = x
        output = self.prelu(self.bn1(self.conv1(x)))
        output = self.bn2(self.conv2(output))
        output = torch.add(output,identity_data)
        return output

##############################################
class OSRRESNET(nn.Module):
    # Pay attention to upscale: the default here is 2, but the layers below are
    # hardcoded for x4 (two PixelShuffle(2) stages) regardless of the argument.
    def __init__(self, in_nc, out_nc, nf, nb, upscale=2, norm_type='batch', act_type='relu', \
            mode='NAC', res_scale=1, upsample_mode='upconv'):
        super(OSRRESNET,self).__init__()
        self.conv_input=nn.Conv2d(in_channels=in_nc,out_channels=nf,kernel_size=9,stride=1,padding=4,bias=False)
        self.prelu=nn.PReLU()
        # 16 residual blocks are hardcoded; the nb argument is not used
        self.residual=self.make_layer(O_Residual_Block,16)
        self.conv_mid = nn.Conv2d(in_channels=nf, out_channels=nf, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_mid=nn.BatchNorm2d(64, affine=True)

        self.upscale4x = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelShuffle(2),
            nn.PReLU(),
        )

        self.conv_output = nn.Conv2d(in_channels=nf, out_channels=out_nc, kernel_size=9, stride=1, padding=4, bias=False)

    def make_layer(self, block, num_of_layer):
        layers = []
        for _ in range(num_of_layer):
            layers.append(block())
        return nn.Sequential(*layers)

    def forward(self,x):
        out = self.prelu(self.conv_input(x))
        residual = out
        out = self.residual(out)
        out = self.bn_mid(self.conv_mid(out))
        out = torch.add(out,residual)
        out = self.upscale4x(out)
        out = self.conv_output(out)
        return out
############################################################
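For comparison with the 23-block variant, the learnable parameters of this version can also be counted from the layer definitions: the convs have no bias, each `BatchNorm2d(affine=True)` adds a scale and a shift per channel, and each `nn.PReLU()` adds one shared slope. A sketch under those assumptions, with 64 channels and 16 blocks as hardcoded above; the helper names are illustrative.

```python
def conv_weights(in_nc, out_nc, k):
    """Weight count of a bias-free k x k Conv2d."""
    return in_nc * out_nc * k * k


def osrresnet_params(nf=64, in_nc=3, out_nc=3, n_blocks=16):
    bn = 2 * nf   # BatchNorm2d(affine=True): one scale and one shift per channel
    prelu = 1     # nn.PReLU() defaults to a single shared slope
    block = 2 * conv_weights(nf, nf, 3) + 2 * bn + prelu
    total = conv_weights(in_nc, nf, 9) + prelu          # conv_input + prelu
    total += n_blocks * block                           # residual trunk
    total += conv_weights(nf, nf, 3) + bn               # conv_mid + bn_mid
    total += 2 * (conv_weights(nf, 4 * nf, 3) + prelu)  # upscale4x
    total += conv_weights(nf, out_nc, 9)                # conv_output
    return total


print(osrresnet_params())
```

This gives 1,546,771 learnable parameters. BatchNorm running statistics are buffers and are not included.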

The settings are as follows:

{
  "name": "or_sr_resnet_x4"//"001_RRDB_PSNR_x4_DIV2K" //  please remove "debug_" during training or tensorboard would not work
  , "use_tb_logger": true
  , "model":"sr"
  , "scale": 4
  , "gpu_ids": [2]

  , "datasets": {
    "train": {
      "name": "DIV2K800"
      , "mode": "LRHR"
      , "dataroot_HR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub"
      , "dataroot_LR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicLRx4"
      , "subset_file": null
      , "use_shuffle": true
      , "n_workers": 8
      , "batch_size": 16// how many samples in each iteration
      , "HR_size": 128 // 128 | 192
      , "use_flip": true
      , "use_rot": true
    }
    , "val": {
      "name": "val_set5"
      , "mode": "LRHR"
      , "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5"
      , "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_sub_bicLRx4"
    }
  }

  , "path": {
    "root": "/home/guanwp/BasicSR-master",
    "pretrain_model_G": null
    , "experiments_root": "/home/guanwp/BasicSR-master/experiments/",
    "models": "/home/guanwp/BasicSR-master/experiments/or_sr_resnet_x4/models",
    "log": "/home/guanwp/BasicSR-master/experiments/or_sr_resnet_x4",
    "val_images": "/home/guanwp/BasicSR-master/experiments/or_sr_resnet_x4/val_images"
  }

  , "network_G": {
    "which_model_G": "srresnet"//"sr_resnet"//"fsrcnn"//"sr_resnet" // RRDB_net | sr_resnet
    , "norm_type": null
    , "mode": "CNA"
    , "nf": 64//56//64
    , "nb": 23
    , "in_nc": 3
    , "out_nc": 3
    , "gc": 32
    , "group": 1
  }

  , "train": {
    "lr_G": 2e-4//1e-3//2e-4
    , "lr_scheme": "MultiStepLR"
    , "lr_steps": [200000,400000,600000,800000,1000000,1500000]
    , "lr_gamma": 0.5

    , "pixel_criterion": "l1"//"l1"//'l2'//huber//Cross
    , "pixel_weight": 1.0
    , "val_freq": 5e3

    , "manual_seed": 0
    , "niter": 2e6//2e6//1e6
  }

  , "logger": {
    "print_freq": 200
    , "save_checkpoint_freq": 5e3
  }
}

The network structure is as follows:

-------------- Generator --------------
OSRRESNET(
  (conv_input): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4), bias=False)
  (prelu): PReLU(num_parameters=1)
  (residual): Sequential(
    (0): O_Residual_Block(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): O_Residual_Block(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): O_Residual_Block(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): O_Residual_Block(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (4): O_Residual_Block(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (5): O_Residual_Block(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (6): O_Residual_Block(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (7): O_Residual_Block(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (8): O_Residual_Block(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (9): O_Residual_Block(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (10): O_Residual_Block(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (11): O_Residual_Block(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (12): O_Residual_Block(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (13): O_Residual_Block(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (14): O_Residual_Block(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (15): O_Residual_Block(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (prelu): PReLU(num_parameters=1)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (conv_mid): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn_mid): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (upscale4x): Sequential(
    (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): PixelShuffle(upscale_factor=2)
    (2): PReLU(num_parameters=1)
    (3): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): PixelShuffle(upscale_factor=2)
    (5): PReLU(num_parameters=1)
  )
  (conv_output): Conv2d(64, 3, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4), bias=False)
)
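As a sanity check on the printout above, the number of trainable parameters can be counted by hand. The sketch below does the arithmetic in plain Python. It follows the printed structure (16 residual blocks, all convolutions with bias=False, each PReLU with a single parameter, each BatchNorm2d with an affine weight and bias), not the "nb" value in the settings file. The helper names conv_params and bn_params are made up for illustration.

```python
def conv_params(in_ch, out_ch, k, bias=False):
    """Weights of a k x k convolution, plus optional bias terms."""
    return in_ch * out_ch * k * k + (out_ch if bias else 0)


def bn_params(ch):
    """Affine BatchNorm2d: one weight and one bias per channel.

    The running mean/variance are buffers, not trainable parameters.
    """
    return 2 * ch


PRELU = 1  # PReLU(num_parameters=1) has a single learnable slope

# One O_Residual_Block: conv1, bn1, prelu, conv2, bn2
block = (conv_params(64, 64, 3) + bn_params(64) + PRELU
         + conv_params(64, 64, 3) + bn_params(64))

# upscale4x: two stages of conv(64 -> 256) + PixelShuffle (no parameters) + PReLU
upscale = 2 * (conv_params(64, 256, 3) + PRELU)

total = (conv_params(3, 64, 9) + PRELU      # conv_input + prelu
         + 16 * block                       # residual trunk
         + conv_params(64, 64, 3) + bn_params(64)  # conv_mid + bn_mid
         + upscale
         + conv_params(64, 3, 9))           # conv_output

if __name__ == "__main__":
    print(block)  # 73985
    print(total)  # 1546771
```

So the model has roughly 1.5M trainable parameters, and about three quarters of them sit in the 16 residual blocks.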

Additional notes:

About nn.LeakyReLU
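For quick reference, here is what the LeakyReLU activation computes, written as a plain-Python sketch on a single number (PyTorch's nn.LeakyReLU applies the same rule element-wise, with a default negative_slope of 0.01). The function name leaky_relu is only for illustration. Note that the network printed above uses PReLU instead, which has the same form but learns the slope during training.

```python
def leaky_relu(x, negative_slope=0.01):
    """LeakyReLU on a scalar: identity for x >= 0, a small slope for x < 0."""
    return x if x >= 0 else negative_slope * x


if __name__ == "__main__":
    for value in (-2.0, -0.5, 0.0, 1.5):
        print(value, leaky_relu(value))
```

Unlike plain ReLU, negative inputs still pass a small gradient, so units are less likely to get stuck at zero.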

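About the MOS test

[Screenshot from the supplementary file of the original paper]

Since I find NIQE more representative, I won't describe MOS in depth here. A later post will cover NIQE.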

Reposted from: http://xiajz.baihongyu.com/
