时间:2021-05-22
在TensorFlow中,权重的初始化主要是在声明张量的时候进行的。 而PyTorch则提供了另一种方法:首先应该声明张量,然后修改张量的权重。通过调用torch.nn.init包中的多种方法可以将权重初始化为直接访问张量的属性。
1、不初始化的效果
在Pytorch中,定义一个tensor,不进行初始化,打印看看结果:
w = torch.Tensor(3,4)print (w)可以看到这时候的初始化的数值都是随机的,而且特别大,这对网络的训练必定不好,最后导致精度提不上,甚至损失无法收敛。
2、初始化的效果
PyTorch提供了多种参数初始化函数:
torch.nn.init.constant(tensor, val)torch.nn.init.normal(tensor, mean=0, std=1)torch.nn.init.xavier_uniform(tensor, gain=1)等等。详细请参考:http://pytorch.org/docs/nn.html#torch-nn-init
注意上面的初始化函数的参数tensor,虽然写的是tensor,但是也可以是Variable类型的。而神经网络的参数类型Parameter是Variable类的子类,所以初始化函数可以直接作用于神经网络参数。实际上,我们初始化也是直接去初始化神经网络的参数。
让我们试试效果:
w = torch.Tensor(3,4)torch.nn.init.normal_(w)print (w)3、初始化神经网络的参数
对神经网络的初始化往往放在模型的__init__()函数中,如下所示:
class Net(nn.Module):
def __init__(self, block, layers, num_classes=1000): self.inplanes = 64 super(Net, self).__init__() *** *** #定义自己的网络层 *** for m in self.modules(): if isinstance(m, nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels m.weight.data.normal_(0, math.sqrt(2. / n)) elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_()****** #定义后续的函数***也可以采取另一种方式:
定义一个权重初始化函数,如下:
def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv2d') != -1: init.xavier_normal_(m.weight.data) init.constant_(m.bias.data, 0.0) elif classname.find('Linear') != -1: init.xavier_normal_(m.weight.data) init.constant_(m.bias.data, 0.0)在模型声明时,调用初始化函数,初始化神经网络参数:
model = Net(*****)model.apply(weights_init)以上这篇Pytorch 实现权重初始化就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。
权重初始化对于训练神经网络至关重要,好的初始化权重可以有效的避免梯度消失等问题的发生。在pytorch的使用过程中有几种权重初始化的方法供大家参考。注意:第一种
在常见的pytorch代码中,我们见到的初始化方式都是调用init类对每层所有参数进行初始化。但是,有时我们有些特殊需求,比如用某一层的权重取优化其它层,或者手
pytorch实现线性回归代码练习实例,供大家参考,具体内容如下欢迎大家指正,希望可以通过小的练习提升对于pytorch的掌握#随机初始化一个二维数据集,使用朋
在神经网络训练中,好的权重初始化会加速训练过程。下面说一下kernel_initializer权重初始化的方法。不同的层可能使用不同的关键字来传递初始化方法,一
路径:https://pytorch.org/docs/master/nn.init.html#nn-init-doc初始化函数:torch.nn.init#-