如何解析Pytorch基础中网络参数初始化问题
如何解析Pytorch基础中网络参数初始化问题,很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。
成都创新互联专注于二连浩特企业网站建设,成都响应式网站建设公司,商城建设。二连浩特网站建设公司,为二连浩特等地区提供建站服务。全流程定制网站建设,专业设计,全程项目跟踪,成都创新互联专业和态度为您提供的服务
参数访问和遍历:
对于模型参数,我们可以进行访问;
由于Sequential由Module继承而来,所以可以使用Module钟的parameter()或者named_parameters方法来访问所有的参数;
例如,对于使用Sequential搭建的网络,可以使用下列for循环直接进行遍历:
for name, param in net.named_parameters(): print(name, param.size())
当然,也可以使用索引来按层访问,因为本身网络也是按层搭建的:
for name, param in net[0].named_parameters(): print(name, param.size(), type(param))
当我们获取某一层的参数信息后,可以使用data()和grad()函数来进行值和梯度的访问:
weight_0 = list(net[0].parameters())[0] print(weight_0.data) print(weight_0.grad) # 反向传播前梯度为None Y.backward() print(weight_0.grad)
参数初始化问题:
当我们参用for循环获取每层参数,可以采用如下形式对w和偏置b进行初值设定:
for name, param in net.named_parameters(): if 'weight' in name: init.normal_(param, mean=0, std=0.01) print(name, param.data) for name, param in net.named_parameters(): if 'bias' in name: init.constant_(param, val=0) print(name, param.data)
当然,我们也可以进行初始化函数的自定义设置:
def init_weight_(tensor): with torch.no_grad(): tensor.uniform_(-10, 10) tensor *= (tensor.abs() >= 5).float() for name, param in net.named_parameters(): if 'weight' in name: init_weight_(param) print(name, param.data)
这里注意一下torch.no_grad()的问题;
该形式表示该参数并不随着backward进行更改,常常用来进行局部网络参数固定的情况;
如该连接所示:关于no_grad()
共享参数:
可以自定义Module类,在forward中多次调用同一个层实现;
如上章节的代码所示:
class FancyMLP(nn.Module): def __init__(self, **kwargs): super(FancyMLP, self).__init__(**kwargs) self.rand_weight = torch.rand((20, 20), requires_grad=False) # 不可训练参数(常数参数) self.linear = nn.Linear(20, 20) def forward(self, x): x = self.linear(x) # 使用创建的常数参数,以及nn.functional中的relu函数和mm函数 x = nn.functional.relu(torch.mm(x, self.rand_weight.data) + 1) # 复用全连接层。等价于两个全连接层共享参数 x = self.linear(x) # 控制流,这里我们需要调用item函数来返回标量进行比较 while x.norm().item() > 1: x /= 2 if x.norm().item() < 0.8: x *= 10 return x.sum()
所以可以看到,相当于同时在同一个网络中调用两次相同的Linear实例,所以变相实现了参数共享;
suo'yi注意一下,如果传入Sequential模块的多层都是同一个Module实例的话,则他们共享参数;
看完上述内容是否对您有帮助呢?如果还想对相关知识有进一步的了解或阅读更多相关文章,请关注创新互联行业资讯频道,感谢您对创新互联的支持。
文章名称:如何解析Pytorch基础中网络参数初始化问题
分享地址:http://scyanting.com/article/gsphcp.html