在pytorch中动态调整优化器的学习率方式


Posted in Python onJune 24, 2020

在深度学习中,经常需要动态调整学习率,以达到更好地训练效果,本文纪录在pytorch中的实现方法,其优化器实例为SGD优化器,其他如Adam优化器同样适用。

一般来说,在以SGD优化器作为基本优化器,然后根据epoch实现学习率指数下降,代码如下:

step = [10,20,30,40]
base_lr = 1e-4
sgd_opt = torch.optim.SGD(model.parameters(), lr=base_lr, nesterov=True, momentum=0.9)
def adjust_lr(epoch):
 lr = base_lr * (0.1 ** np.sum(epoch >= np.array(step)))
 for params_group in sgd_opt.param_groups:
  params_group['lr'] = lr
 return lr

只需要在每个train的epoch之前使用这个函数即可。

for epoch in range(60):
 model.train()
 adjust_lr(epoch)
 for ind, each in enumerate(train_loader):
 mat, label = each
 ...

补充知识:Pytorch框架下应用Bi-LSTM实现汽车评论文本关键词抽取

需要调用的模块及整体Bi-lstm流程

import torch
import pandas as pd
import numpy as np
from tensorflow import keras
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import gensim
from sklearn.model_selection import train_test_split
class word_extract(nn.Module):
 def __init__(self,d_model,embedding_matrix):
  super(word_extract, self).__init__()
  self.d_model=d_model
  self.embedding=nn.Embedding(num_embeddings=len(embedding_matrix),embedding_dim=200)
  self.embedding.weight.data.copy_(embedding_matrix)
  self.embedding.weight.requires_grad=False
  self.lstm1=nn.LSTM(input_size=200,hidden_size=50,bidirectional=True)
  self.lstm2=nn.LSTM(input_size=2*self.lstm1.hidden_size,hidden_size=50,bidirectional=True)
  self.linear=nn.Linear(2*self.lstm2.hidden_size,4)

 def forward(self,x):
  w_x=self.embedding(x)
  first_x,(first_h_x,first_c_x)=self.lstm1(w_x)
  second_x,(second_h_x,second_c_x)=self.lstm2(first_x)
  output_x=self.linear(second_x)
  return output_x

将文本转换为数值形式

def trans_num(word2idx,text):
 text_list=[]
 for i in text:
  s=i.rstrip().replace('\r','').replace('\n','').split(' ')
  numtext=[word2idx[j] if j in word2idx.keys() else word2idx['_PAD'] for j in s ]
  text_list.append(numtext)
 return text_list

将Gensim里的词向量模型转为矩阵形式,后续导入到LSTM模型中

def establish_word2vec_matrix(model): #负责将数值索引转为要输入的数据
 word2idx = {"_PAD": 0} # 初始化 `[word : token]` 字典,后期 tokenize 语料库就是用该词典。
 num2idx = {0: "_PAD"}
 vocab_list = [(k, model.wv[k]) for k, v in model.wv.vocab.items()]

 # 存储所有 word2vec 中所有向量的数组,留意其中多一位,词向量全为 0, 用于 padding
 embeddings_matrix = np.zeros((len(model.wv.vocab.items()) + 1, model.vector_size))
 for i in range(len(vocab_list)):
  word = vocab_list[i][0]
  word2idx[word] = i + 1
  num2idx[i + 1] = word
  embeddings_matrix[i + 1] = vocab_list[i][1]
 embeddings_matrix = torch.Tensor(embeddings_matrix)
 return embeddings_matrix, word2idx, num2idx

训练过程

def train(model,epoch,learning_rate,batch_size,x, y, val_x, val_y):
 optimizor = optim.Adam(model.parameters(), lr=learning_rate)
 data = TensorDataset(x, y)
 data = DataLoader(data, batch_size=batch_size)
 for i in range(epoch):
  for j, (per_x, per_y) in enumerate(data):
   output_y = model(per_x)
   loss = F.cross_entropy(output_y.view(-1,output_y.size(2)), per_y.view(-1))
   optimizor.zero_grad()
   loss.backward()
   optimizor.step()
   arg_y=output_y.argmax(dim=2)
   fit_correct=(arg_y==per_y).sum()
   fit_acc=fit_correct.item()/(per_y.size(0)*per_y.size(1))
   print('##################################')
   print('第{}次迭代第{}批次的训练误差为{}'.format(i + 1, j + 1, loss), end=' ')
   print('第{}次迭代第{}批次的训练准确度为{}'.format(i + 1, j + 1, fit_acc))
   val_output_y = model(val_x)
   val_loss = F.cross_entropy(val_output_y.view(-1,val_output_y.size(2)), val_y.view(-1))
   arg_val_y=val_output_y.argmax(dim=2)
   val_correct=(arg_val_y==val_y).sum()
   val_acc=val_correct.item()/(val_y.size(0)*val_y.size(1))
   print('第{}次迭代第{}批次的预测误差为{}'.format(i + 1, j + 1, val_loss), end=' ')
   print('第{}次迭代第{}批次的预测准确度为{}'.format(i + 1, j + 1, val_acc))
 torch.save(model,'./extract_model.pkl')#保存模型

主函数部分

if __name__ =='__main__':
 #生成词向量矩阵
 word2vec = gensim.models.Word2Vec.load('./word2vec_model')
 embedding_matrix,word2idx,num2idx=establish_word2vec_matrix(word2vec)#输入的是词向量模型
 #
 train_data=pd.read_csv('./数据.csv')
 x=list(train_data['文本'])
 # 将文本从文字转化为数值,这部分trans_num函数你需要自己改动去适应你自己的数据集
 x=trans_num(word2idx,x)
 #x需要先进行填充,也就是每个句子都是一样长度,不够长度的以0来填充,填充词单独分为一类
 # #也就是说输入的x是固定长度的数值列表,例如[50,123,1850,21,199,0,0,...]
 #输入的y是[2,0,1,0,0,1,3,3,3,3,3,.....]
 #填充代码你自行编写,以下部分是针对我的数据集
 x=keras.preprocessing.sequence.pad_sequences(
   x,maxlen=60,value=0,padding='post',
 )
 y=list(train_data['BIO数值'])
 y_text=[]
 for i in y:
  s=i.rstrip().split(' ')
  numtext=[int(j) for j in s]
  y_text.append(numtext)
 y=y_text
 y=keras.preprocessing.sequence.pad_sequences(
   y,maxlen=60,value=3,padding='post',
  )
 # 将数据进行划分
 fit_x,val_x,fit_y,val_y=train_test_split(x,y,train_size=0.8,test_size=0.2)
 fit_x=torch.LongTensor(fit_x)
 fit_y=torch.LongTensor(fit_y)
 val_x=torch.LongTensor(val_x)
 val_y=torch.LongTensor(val_y)
 #开始应用
 w_extract=word_extract(d_model=200,embedding_matrix=embedding_matrix)
 train(model=w_extract,epoch=5,learning_rate=0.001,batch_size=50,
   x=fit_x,y=fit_y,val_x=val_x,val_y=val_y)#可以自行改动参数,设置学习率,批次,和迭代次数
 w_extract=torch.load('./extract_model.pkl')#加载保存好的模型
 pred_val_y=w_extract(val_x).argmax(dim=2)

以上这篇在pytorch中动态调整优化器的学习率方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python ldap实现登录实例代码
Sep 30 Python
Python实现爬取需要登录的网站完整示例
Aug 19 Python
Python基于回溯法子集树模板实现图的遍历功能示例
Sep 05 Python
Python3计算三角形的面积代码
Dec 18 Python
深入浅析Python中的yield关键字
Jan 24 Python
Python实现绘制双柱状图并显示数值功能示例
Jun 23 Python
python3实现爬取淘宝美食代码分享
Sep 23 Python
python点击鼠标获取坐标(Graphics)
Aug 10 Python
浅谈Python_Openpyxl使用(最全总结)
Sep 05 Python
python 实现简单的FTP程序
Dec 27 Python
pygame实现弹球游戏
Apr 14 Python
Python根据字典的值查询出对应的键的方法
Sep 30 Python
CentOS 7如何实现定时执行python脚本
Jun 24 #Python
python tkiner实现 一个小小的图片翻页功能的示例代码
Jun 24 #Python
在tensorflow实现直接读取网络的参数(weight and bias)的值
Jun 24 #Python
基于pytorch中的Sequential用法说明
Jun 24 #Python
django haystack实现全文检索的示例代码
Jun 24 #Python
Python爬虫如何应对Cloudflare邮箱加密
Jun 24 #Python
python使用自定义钉钉机器人的示例代码
Jun 24 #Python
You might like
php 高效率写法 推荐
2010/02/21 PHP
php实现的支持imagemagick及gd库两种处理的缩略图生成类
2014/09/23 PHP
php实现的用户查询类实例
2015/06/18 PHP
Laravel validate error处理,ajax,json示例
2019/10/25 PHP
php+websocket 实现的聊天室功能详解
2020/05/27 PHP
javascript+mapbar实现地图定位
2010/04/09 Javascript
jQuery与ExtJS之选择实例分析
2010/08/19 Javascript
js文件Cookie存取值示例代码
2014/02/20 Javascript
鼠标滑过出现预览的大图提示效果
2014/02/26 Javascript
JavaScript html5 canvas绘制时钟效果(二)
2016/03/27 Javascript
基于jQuery实现瀑布流页面
2017/04/11 jQuery
浅谈angular2 组件的生命周期钩子
2017/08/12 Javascript
实例分析js事件循环机制
2017/12/13 Javascript
Vue2.0子同级组件之间数据交互方法
2018/02/28 Javascript
vue2.0$nextTick监听数据渲染完成之后的回调函数方法
2018/09/11 Javascript
nodejs基础之常用工具模块util用法分析
2018/12/26 NodeJs
node.js中express模块创建服务器和http模块客户端发请求
2019/03/06 Javascript
JavaScript中的 new 命令
2019/05/22 Javascript
vue组件三大核心概念图文详解
2019/05/30 Javascript
OpenLayers3实现地图鹰眼以及地图比例尺的添加
2020/09/25 Javascript
Python设计模式之MVC模式简单示例
2018/01/10 Python
TensorFlow实现创建分类器
2018/02/06 Python
解决Python 使用h5py加载文件,看不到keys()的问题
2019/02/08 Python
python输出带颜色字体实例方法
2019/09/01 Python
python dir函数快速掌握用法技巧
2020/12/09 Python
html5定位并在百度地图上显示的示例
2014/04/27 HTML / CSS
Net-A-Porter美国官网:全球时尚奢侈品名站
2017/02/11 全球购物
英国50岁以上人群的交友网站:Ourtime
2018/03/28 全球购物
公司年会主持词
2014/03/22 职场文书
师范生免费教育协议书范本
2014/10/09 职场文书
二年级语文下册复习计划
2015/01/19 职场文书
观后感开头
2015/06/19 职场文书
小区物业管理2015年度工作总结
2015/10/22 职场文书
趣味运动会标语口号
2015/12/26 职场文书
2019最新版劳务派遣管理制度
2019/08/16 职场文书
在HTML5 localStorage中存储对象的示例代码
2021/04/21 Javascript