在pytorch中动态调整优化器的学习率方式


Posted in Python onJune 24, 2020

在深度学习中,经常需要动态调整学习率,以达到更好地训练效果,本文纪录在pytorch中的实现方法,其优化器实例为SGD优化器,其他如Adam优化器同样适用。

一般来说,在以SGD优化器作为基本优化器,然后根据epoch实现学习率指数下降,代码如下:

step = [10,20,30,40]
base_lr = 1e-4
sgd_opt = torch.optim.SGD(model.parameters(), lr=base_lr, nesterov=True, momentum=0.9)
def adjust_lr(epoch):
 lr = base_lr * (0.1 ** np.sum(epoch >= np.array(step)))
 for params_group in sgd_opt.param_groups:
  params_group['lr'] = lr
 return lr

只需要在每个train的epoch之前使用这个函数即可。

for epoch in range(60):
 model.train()
 adjust_lr(epoch)
 for ind, each in enumerate(train_loader):
 mat, label = each
 ...

补充知识:Pytorch框架下应用Bi-LSTM实现汽车评论文本关键词抽取

需要调用的模块及整体Bi-lstm流程

import torch
import pandas as pd
import numpy as np
from tensorflow import keras
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import gensim
from sklearn.model_selection import train_test_split
class word_extract(nn.Module):
 def __init__(self,d_model,embedding_matrix):
  super(word_extract, self).__init__()
  self.d_model=d_model
  self.embedding=nn.Embedding(num_embeddings=len(embedding_matrix),embedding_dim=200)
  self.embedding.weight.data.copy_(embedding_matrix)
  self.embedding.weight.requires_grad=False
  self.lstm1=nn.LSTM(input_size=200,hidden_size=50,bidirectional=True)
  self.lstm2=nn.LSTM(input_size=2*self.lstm1.hidden_size,hidden_size=50,bidirectional=True)
  self.linear=nn.Linear(2*self.lstm2.hidden_size,4)

 def forward(self,x):
  w_x=self.embedding(x)
  first_x,(first_h_x,first_c_x)=self.lstm1(w_x)
  second_x,(second_h_x,second_c_x)=self.lstm2(first_x)
  output_x=self.linear(second_x)
  return output_x

将文本转换为数值形式

def trans_num(word2idx,text):
 text_list=[]
 for i in text:
  s=i.rstrip().replace('\r','').replace('\n','').split(' ')
  numtext=[word2idx[j] if j in word2idx.keys() else word2idx['_PAD'] for j in s ]
  text_list.append(numtext)
 return text_list

将Gensim里的词向量模型转为矩阵形式,后续导入到LSTM模型中

def establish_word2vec_matrix(model): #负责将数值索引转为要输入的数据
 word2idx = {"_PAD": 0} # 初始化 `[word : token]` 字典,后期 tokenize 语料库就是用该词典。
 num2idx = {0: "_PAD"}
 vocab_list = [(k, model.wv[k]) for k, v in model.wv.vocab.items()]

 # 存储所有 word2vec 中所有向量的数组,留意其中多一位,词向量全为 0, 用于 padding
 embeddings_matrix = np.zeros((len(model.wv.vocab.items()) + 1, model.vector_size))
 for i in range(len(vocab_list)):
  word = vocab_list[i][0]
  word2idx[word] = i + 1
  num2idx[i + 1] = word
  embeddings_matrix[i + 1] = vocab_list[i][1]
 embeddings_matrix = torch.Tensor(embeddings_matrix)
 return embeddings_matrix, word2idx, num2idx

训练过程

def train(model,epoch,learning_rate,batch_size,x, y, val_x, val_y):
 optimizor = optim.Adam(model.parameters(), lr=learning_rate)
 data = TensorDataset(x, y)
 data = DataLoader(data, batch_size=batch_size)
 for i in range(epoch):
  for j, (per_x, per_y) in enumerate(data):
   output_y = model(per_x)
   loss = F.cross_entropy(output_y.view(-1,output_y.size(2)), per_y.view(-1))
   optimizor.zero_grad()
   loss.backward()
   optimizor.step()
   arg_y=output_y.argmax(dim=2)
   fit_correct=(arg_y==per_y).sum()
   fit_acc=fit_correct.item()/(per_y.size(0)*per_y.size(1))
   print('##################################')
   print('第{}次迭代第{}批次的训练误差为{}'.format(i + 1, j + 1, loss), end=' ')
   print('第{}次迭代第{}批次的训练准确度为{}'.format(i + 1, j + 1, fit_acc))
   val_output_y = model(val_x)
   val_loss = F.cross_entropy(val_output_y.view(-1,val_output_y.size(2)), val_y.view(-1))
   arg_val_y=val_output_y.argmax(dim=2)
   val_correct=(arg_val_y==val_y).sum()
   val_acc=val_correct.item()/(val_y.size(0)*val_y.size(1))
   print('第{}次迭代第{}批次的预测误差为{}'.format(i + 1, j + 1, val_loss), end=' ')
   print('第{}次迭代第{}批次的预测准确度为{}'.format(i + 1, j + 1, val_acc))
 torch.save(model,'./extract_model.pkl')#保存模型

主函数部分

if __name__ =='__main__':
 #生成词向量矩阵
 word2vec = gensim.models.Word2Vec.load('./word2vec_model')
 embedding_matrix,word2idx,num2idx=establish_word2vec_matrix(word2vec)#输入的是词向量模型
 #
 train_data=pd.read_csv('./数据.csv')
 x=list(train_data['文本'])
 # 将文本从文字转化为数值,这部分trans_num函数你需要自己改动去适应你自己的数据集
 x=trans_num(word2idx,x)
 #x需要先进行填充,也就是每个句子都是一样长度,不够长度的以0来填充,填充词单独分为一类
 # #也就是说输入的x是固定长度的数值列表,例如[50,123,1850,21,199,0,0,...]
 #输入的y是[2,0,1,0,0,1,3,3,3,3,3,.....]
 #填充代码你自行编写,以下部分是针对我的数据集
 x=keras.preprocessing.sequence.pad_sequences(
   x,maxlen=60,value=0,padding='post',
 )
 y=list(train_data['BIO数值'])
 y_text=[]
 for i in y:
  s=i.rstrip().split(' ')
  numtext=[int(j) for j in s]
  y_text.append(numtext)
 y=y_text
 y=keras.preprocessing.sequence.pad_sequences(
   y,maxlen=60,value=3,padding='post',
  )
 # 将数据进行划分
 fit_x,val_x,fit_y,val_y=train_test_split(x,y,train_size=0.8,test_size=0.2)
 fit_x=torch.LongTensor(fit_x)
 fit_y=torch.LongTensor(fit_y)
 val_x=torch.LongTensor(val_x)
 val_y=torch.LongTensor(val_y)
 #开始应用
 w_extract=word_extract(d_model=200,embedding_matrix=embedding_matrix)
 train(model=w_extract,epoch=5,learning_rate=0.001,batch_size=50,
   x=fit_x,y=fit_y,val_x=val_x,val_y=val_y)#可以自行改动参数,设置学习率,批次,和迭代次数
 w_extract=torch.load('./extract_model.pkl')#加载保存好的模型
 pred_val_y=w_extract(val_x).argmax(dim=2)

以上这篇在pytorch中动态调整优化器的学习率方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python定时检查启动某个exe程序适合检测exe是否挂了
Jan 21 Python
Python实现二分法算法实例
Feb 02 Python
Python中的ConfigParser模块使用详解
May 04 Python
详解Python中expandtabs()方法的使用
May 18 Python
Python3 读、写Excel文件的操作方法
Oct 20 Python
python三方库之requests的快速上手
Mar 04 Python
python在openstreetmap地图上绘制路线图的实现
Jul 11 Python
下载官网python并安装的步骤详解
Oct 12 Python
python实现门限回归方式
Feb 29 Python
Python实现打包成库供别的模块调用
Jul 13 Python
详解python UDP 编程
Aug 24 Python
Python内置类型集合set和frozenset的使用详解
Apr 26 Python
CentOS 7如何实现定时执行python脚本
Jun 24 #Python
python tkiner实现 一个小小的图片翻页功能的示例代码
Jun 24 #Python
在tensorflow实现直接读取网络的参数(weight and bias)的值
Jun 24 #Python
基于pytorch中的Sequential用法说明
Jun 24 #Python
django haystack实现全文检索的示例代码
Jun 24 #Python
Python爬虫如何应对Cloudflare邮箱加密
Jun 24 #Python
python使用自定义钉钉机器人的示例代码
Jun 24 #Python
You might like
php设计模式 Interpreter(解释器模式)
2011/06/26 PHP
PHP实现图片上传并压缩
2015/12/22 PHP
jquery tools之tooltip
2009/07/25 Javascript
跟我一起学写jQuery插件开发方法(附完整实例及下载)
2010/04/01 Javascript
与jquery serializeArray()一起使用的函数,主要来方便提交表单
2011/01/31 Javascript
JQuery实现鼠标滑过显示导航下拉列表
2013/09/12 Javascript
js字符串截取函数substr substring slice使用对比
2013/11/27 Javascript
基于NodeJS的前后端分离的思考与实践(二)模版探索
2014/09/26 NodeJs
JavaScript实现的简单拖拽效果
2015/06/01 Javascript
js简单实现Select互换数据的方法
2015/08/17 Javascript
AngularJs 国际化(I18n/L10n)详解
2016/09/01 Javascript
Html中 IFrame的用法及注意点
2016/12/22 Javascript
JavaScript获取中英文混合字符串长度的方法示例
2017/02/04 Javascript
神级程序员JavaScript300行代码搞定汉字转拼音
2017/05/20 Javascript
MUI实现上拉加载和下拉刷新效果
2017/06/30 Javascript
JS中定位 position 的使用实例代码
2017/08/06 Javascript
JavaScript中递归实现的方法及其区别
2017/09/12 Javascript
详解如何使用webpack打包多页jquery项目
2019/02/01 jQuery
手把手带你封装一个vue component第三方库
2019/02/14 Javascript
Android 自定义view仿微信相机单击拍照长按录视频按钮
2019/07/19 Javascript
vue2路由方式--嵌套路由实现方法分析
2020/03/06 Javascript
three.js 如何制作魔方
2020/07/31 Javascript
python3.6连接MySQL和表的创建与删除实例代码
2017/12/28 Python
利用Python如何将数据写到CSV文件中
2018/06/05 Python
浅谈Django2.0 加xadmin踩的坑
2019/11/15 Python
PyTorch中反卷积的用法详解
2019/12/30 Python
PyQt5事件处理之定时在控件上显示信息的代码
2020/03/25 Python
python中把元组转换为namedtuple方法
2020/12/09 Python
简单总结CSS3中视窗单位Viewport的常见用法
2016/02/04 HTML / CSS
Kipling意大利官网:世界著名的时尚休闲包袋品牌
2019/06/05 全球购物
澳大利亚在线购买葡萄酒:The Wine Collective
2020/02/20 全球购物
中学生寄语大全
2014/04/03 职场文书
《去年的树》教学反思
2016/02/18 职场文书
2019年怎样才能撰写出优秀的自荐信
2019/03/25 职场文书
《鲁班学艺》读后感3篇
2019/11/27 职场文书
mysql的单列多值存储实例详解
2022/04/05 MySQL