时间:2021-05-22
pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)
(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)
优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)
备注:
1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"
torch.save(model.state_dict(), PATH)2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用
model = TheModelClass(*args, **kwargs)model.load_state_dict(torch.load(PATH))model.eval()注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.
模态字典(state_dict)的保存(model是一个网络结构类的对象)
1.1)仅保存学习到的参数,用以下命令
torch.save(model.state_dict(), PATH)1.2)加载model.state_dict,用以下命令
model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval()备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名
2.1)保存整个model的状态,用以下命令
torch.save(model,PATH)2.2)加载整个model的状态,用以下命令:
# Model class must be defined somewhere model = torch.load(PATH) model.eval()state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项
如何仅加载某一层的训练的到的参数(某一层的state)
If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)
for param in list(model.pretrained.parameters()): param.requires_grad = False注意: requires_grad的操作对象是tensor.
疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False
回答:经测试,不可以.model.conv1 没有requires_grad属性.
全部测试代码:
#-*-coding:utf-8-*-import torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optim # define modelclass TheModelClass(nn.Module): def __init__(self): super(TheModelClass,self).__init__() self.conv1 = nn.Conv2d(3,6,5) self.pool = nn.MaxPool2d(2,2) self.conv2 = nn.Conv2d(6,16,5) self.fc1 = nn.Linear(16*5*5,120) self.fc2 = nn.Linear(120,84) self.fc3 = nn.Linear(84,10) def forward(self,x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1,16*5*5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x # initial modelmodel = TheModelClass() #initialize the optimizeroptimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9) # print the model's state_dictprint("model's state_dict:")for param_tensor in model.state_dict(): print(param_tensor,'\t',model.state_dict()[param_tensor].size()) print("\noptimizer's state_dict")for var_name in optimizer.state_dict(): print(var_name,'\t',optimizer.state_dict()[var_name]) print("\nprint particular param")print('\n',model.conv1.weight.size())print('\n',model.conv1.weight) print("------------------------------------")torch.save(model.state_dict(),'./model_state_dict.pt')# model_2 = TheModelClass()# model_2.load_state_dict(torch.load('./model_state_dict'))# model.eval()# print('\n',model_2.conv1.weight)# print((model_2.conv1.weight == model.conv1.weight).size())## 仅仅加载某一层的参数conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']print(conv1_weight_state==model.conv1.weight) model_2 = TheModelClass()model_2.load_state_dict(torch.load('./model_state_dict.pt'))model_2.conv1.requires_grad=Falseprint(model_2.conv1.requires_grad)print(model_2.conv1.bias.requires_grad)以上这篇pytorch 状态字典:state_dict使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。
在pytorch训练过程中可以通过下面这一句代码来打印当前学习率print(net.optimizer.state_dict()['param_groups']
先说结论model.state_dict()是浅拷贝,返回的参数仍然会随着网络的训练而变化。应该使用deepcopy(model.state_dict()),或
Android使用selector改变按钮状态实例详解在res/drawable文件夹新增一个文件,此文件设置了图片的触发状态,你可以设置:state_pres
在Python中使用字典,格式如下:dict={key1:value1,key2;value2...}在实际访问字典值时的使用格式如下:dict[key]多键值
一.字典的基本方法1.新建字典1)、建立一个空的字典>>>dict1={}>>>dict2=dict()>>>dict1,dict2({},{})2)、新建的时