时间:2021-05-22
工作中有时候需要对vgg进行定制化处理,比如有些时候需要借助于vgg的层结构,但是需要使用的是2 channels输入,等等需求,这时候可以使用vgg的原始结构用class重写一遍,但是这样的方式比较慢,并且容易出错,下面给出一种比较简单的方式
def define_vgg(vgg,input_channels,endlayer,use_maxpool=False): vgg_ad = copy.deepcopy(vgg) model = nn.Sequential() i = 0 for layer in list(vgg_ad.features): if i > endlayer: break if isinstance(layer, nn.Conv2d) and i is 0: name = "conv_" + str(i) layer = nn.Conv2d(input_channels, layer.out_channels, layer.kernel_size, stride = layer.stride, padding=layer.padding) model.add_module(name, layer) if isinstance(layer, nn.Conv2d): name = "conv_" + str(i) model.add_module(name, layer) if isinstance(layer, nn.ReLU): name = "leakyrelu_" + str(i) layer = nn.LeakyReLU(inplace=True) model.add_module(name, layer) if isinstance(layer, nn.MaxPool2d): name = "pool_" + str(i) if use_maxpool: model.add_module(name, layer) else: avgpool = nn.AvgPool2d(kernel_size=layer.kernel_size, stride=layer.stride, padding=layer.padding) model.add_module(name, avgpool) i += 1 return model函数输入项中的vgg 是直接使用的import torchvision.models.vgg16 传入的是vgg16 非预训练版本。end_layer 是需要提取的层数,这里使用了vgg.features 是指仅仅在vgg.features 上进行层的提取;也可以根据定制在classifier上进行提取。
下面是我的一个提取前7层的示例,可以使用pyCharm evaluate 上面函数返回的model,可以看到这个示例的情况,这里我的定制条件是输入通道为2 ,需要提取前7层,并且将ReLu更换为LeakyRelu。
Sequential( (conv_0): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (leakyrelu_1): LeakyReLU(negative_slope=0.01, inplace) (conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (leakyrelu_3): LeakyReLU(negative_slope=0.01, inplace) (pool_4): AvgPool2d(kernel_size=2, stride=2, padding=0) (conv_5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (leakyrelu_6): LeakyReLU(negative_slope=0.01, inplace) (conv_7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))以上这篇Pytorch 抽取vgg各层并进行定制化处理的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。
模型VGG,数据集cifar。对照这份代码走一遍,大概就知道整个pytorch的运行机制。来源定义模型:'''VGG11/13/16/19inPytorch.'
最近使用pytorch时,需要用到一个预训练好的人脸识别模型提取人脸ID特征,想到很多人都在用用vgg-face,但是vgg-face没有pytorch的模型,
正在看的db2教程是:用shell抽取,更新db2的数据。为工作需要而写的shell处理db2数据库的程序用shell抽取db2的数据,并进行处理。#SQL文定
一、Reference类型(除强引用)可以理解为Reference的直接子类都是由jvm定制化处理的,因此在代码中直接继承于Reference类型没有任何作用.
目标:优化代码,利用多进程,进行近实时预处理、网络预测及后处理:本人尝试了pytorch的multiprocessing,进行多进程同步处理以上任务。fromt