搜索

耀世资讯

公司动态
行业新闻

联系我们

Contact us

电话:400-123-4567
Q Q:1234567890
邮箱:admin@youweb.com
地址:广东省广州市天河区88号

Pytorch框架学习---(4)优化器Optimizer

发布时间:2024-06-18 21:29:30 作者:佚名
<table><tbody><tr><td bgcolor="F5F5DC"> <p>本节讲述Pytorch中torch.optim优化器包,学习率、参数Momentum动量的含义,以及常用的几类优化器。【Latex公式采用<a href="https://www.latexlive.com/" target="_blank" rel="noopener">在线编码器</a>】</p> <p>优化器概念:管理并更新模型所选中的网络参数,使得模型输出更加接近真实标签。</p> </td></tr></tbody></table> <p></p><div class="toc"><div class="toc-container-header">目录</div><ul><li><a href="https://www.cnblogs.com/zpc1001/p/13195928.html#1-optimizer基本属性" rel="noopener">1. Optimizer基本属性</a><ul><li><a href="https://www.cnblogs.com/zpc1001/p/13195928.html#1如何创建一个优化器" rel="noopener">(1)如何创建一个优化器</a></li><li><a href="https://www.cnblogs.com/zpc1001/p/13195928.html#2继承optimizer父类" rel="noopener">(2)继承Optimizer父类</a></li></ul></li><li><a href="https://www.cnblogs.com/zpc1001/p/13195928.html#2optimizer的基本方法" rel="noopener">2.Optimizer的基本方法</a><ul><li><a href="https://www.cnblogs.com/zpc1001/p/13195928.html#1optimizerzero_grad" rel="noopener">(1)optimizer.zero_grad()</a></li><li><a href="https://www.cnblogs.com/zpc1001/p/13195928.html#2optimizerstep" rel="noopener">(2)optimizer.step()</a></li><li><a href="https://www.cnblogs.com/zpc1001/p/13195928.html#3optimizeradd_param_group" rel="noopener">(3)optimizer.add_param_group()</a></li><li><a href="https://www.cnblogs.com/zpc1001/p/13195928.html#4optimizerstate_dict" rel="noopener">(4)optimizer.state_dict()</a></li><li><a href="https://www.cnblogs.com/zpc1001/p/13195928.html#5optimizerload_state_dict" rel="noopener">(5)optimizer.load_state_dict()</a></li></ul></li><li><a href="https://www.cnblogs.com/zpc1001/p/13195928.html#3学习率lr" rel="noopener">3.学习率<em>lr</em></a></li><li><a href="https://www.cnblogs.com/zpc1001/p/13195928.html#4动量momentum" rel="noopener">4.动量<em>Momentum</em></a><ul><li><a href="https://www.cnblogs.com/zpc1001/p/13195928.html#1指数加权平均" rel="noopener">(1)指数加权平均</a></li><li><a href="https://www.cnblogs.com/zpc1001/p/13195928.html#2pytroch中的动量计算" rel="noopener">(2)Pytroch中的动量计算</a></li></ul></li><li><a href="https://www.cnblogs.com/zpc1001/p/13195928.html#5optimsgd随机梯度下降" rel="noopener">5.optim.SGD随机梯度下降</a></li><li><a href="https://www.cnblogs.com/zpc1001/p/13195928.html#6torchoptim下10种优化器" rel="noopener">6.torch.optim下10种优化器</a></li></ul></div><p></p> <pre></pre> <p>? 所有的optim中的优化器都继承Optimizer父类,即:</p> <pre></pre> <p>? 由上式代码<strong><font color="red">注释#</font></strong>可知,重要参数如下:</p> <ul> <li> <p>self.defaults:优化器本身参数,如学习率、动量等等</p> </li> <li> <p>self.state:参数缓存,如动量缓存</p> </li> <li> <p>self.param_groups:管理的参数组,注意这里是list(dict)形式,即列表中字典。</p> </li> </ul> <p>? ? 例如:<class 'list'>: [{'params': [网络参数], 'lr': 0.1, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]</p> <p>? <strong><font color="red">注意:</font></strong>这里模型中的参数(如W)与param_groups中保存的W,地址相同。</p> <p>? 清空所管理的网络参数的梯度</p> <pre></pre> <p>? 执行一步更新,根据对应的梯度下降策略。</p> <p>? 添加参数组,<strong><font color="red">经常用于finetune,又例如设置两部分参数</font></strong>,e.g. 网络分为:特征提取层+全连接分类层,设置两组优化参数。</p> <pre></pre> <p>? 同一个优化器,添加新的优化参数:</p> <pre></pre> <p>? 可以看到添加之后,optimizer.param_groups list中含有两个字典,一个字典是之前的参数,另一个字典是新添加的一系列优化器参数</p> <p>? 获取当前优化器的一系列信息参数。由代码可知,返回的是字典,两个key:'state'和'param_groups'</p> <pre></pre> <p>? self.state:参数缓存,如动量缓存,当网络没有经过optimizer.step(),即没有根据loss.backward()得到的梯度去更新网络参数时,state为空:</p> <pre></pre> <p>当更新之后,'state'将保存'params'中value的地址以及{'momentun_buffer':tensor()}动量缓存,<strong><font color="red">用于后续断点恢复</font></strong>。</p> <p>? 加载保存的状态信息字典</p> <pre></pre> <p></p><div class="math display">\[w_{i+1}=w_{i} - lr\ast grad\left ( w_{i} \right ) \]</div><p></p><p>? 学习率可以看作是对梯度的缩小因子,用来控制梯度更新的步伐:</p> <ul> <li> <p>lr不能过大(易loss激增);</p> </li> <li> <p>lr不能过小(收敛较慢);</p> </li> <li> <p>当设置lr适当小时,如0.01,此时可通过增加网络训练时间,进行弥补;</p> </li> </ul> <p>? 结合当前梯度与上一时刻更新的信息,来更新当前梯度信息。Momentum 梯度下降法 可追溯到指数加权平均:</p> <p></p><div class="math display">\[V_{t}=\beta \cdot V_{t-1} + \left ( 1- \beta \right ) \cdot heta _{t}=\sum_{i=0}^{t}\left ( \left ( 1- \beta \right ) \cdot \beta^{i} \cdot heta _{t-i} \right ) \]</div><p></p><p>其中 <span class="math inline">\( heta _{t}\)</span> 为当前时刻的参数,因为 <span class="math inline">\(\beta < 1\)</span> ,从上述公式可知,距离当前t时刻越远的时刻参数,权重越小,对t时刻影响越小。</p> <p></p><div class="math display">\[\left\{\begin{matrix} V_{i}=m\cdot V_{i-1}+ grad\left ( w_{i} \right ) \\w_{i+1}=w_{i} - lr\cdot V_{i} \end{matrix}\right.\]</div><p></p><p></p><div class="math display">\[\Longrightarrow V_{i}=\sum_{j=0}^{i} \left ( m^{j} \cdot grad\left ( w_{i-j} \right ) \right ) \]</div><p></p><p>可以看到, 当 <span class="math inline">\(Momentum\)</span> 太大时,由于受到前面时刻梯度线性影响,会有一定的震荡。</p> <pre></pre> <p>? 下次来补充啦!</p>
热线电话:400-123-4567
电子邮箱:admin@youweb.com
Q Q:1234567890
地址:广东省广州市天河区88号
备案号:
耀世娱乐-耀世平台-耀世加盟站

关注我们

Copyright © 2002-2017 耀世-耀世平台-耀世加盟站 版权所有

平台注册入口