在pytorch中动态调整优化器的学习率方式


Posted in Python onJune 24, 2020

在深度学习中,经常需要动态调整学习率,以达到更好地训练效果,本文纪录在pytorch中的实现方法,其优化器实例为SGD优化器,其他如Adam优化器同样适用。

一般来说,在以SGD优化器作为基本优化器,然后根据epoch实现学习率指数下降,代码如下:

step = [10,20,30,40]
base_lr = 1e-4
sgd_opt = torch.optim.SGD(model.parameters(), lr=base_lr, nesterov=True, momentum=0.9)
def adjust_lr(epoch):
 lr = base_lr * (0.1 ** np.sum(epoch >= np.array(step)))
 for params_group in sgd_opt.param_groups:
  params_group['lr'] = lr
 return lr

只需要在每个train的epoch之前使用这个函数即可。

for epoch in range(60):
 model.train()
 adjust_lr(epoch)
 for ind, each in enumerate(train_loader):
 mat, label = each
 ...

补充知识:Pytorch框架下应用Bi-LSTM实现汽车评论文本关键词抽取

需要调用的模块及整体Bi-lstm流程

import torch
import pandas as pd
import numpy as np
from tensorflow import keras
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import gensim
from sklearn.model_selection import train_test_split
class word_extract(nn.Module):
 def __init__(self,d_model,embedding_matrix):
  super(word_extract, self).__init__()
  self.d_model=d_model
  self.embedding=nn.Embedding(num_embeddings=len(embedding_matrix),embedding_dim=200)
  self.embedding.weight.data.copy_(embedding_matrix)
  self.embedding.weight.requires_grad=False
  self.lstm1=nn.LSTM(input_size=200,hidden_size=50,bidirectional=True)
  self.lstm2=nn.LSTM(input_size=2*self.lstm1.hidden_size,hidden_size=50,bidirectional=True)
  self.linear=nn.Linear(2*self.lstm2.hidden_size,4)

 def forward(self,x):
  w_x=self.embedding(x)
  first_x,(first_h_x,first_c_x)=self.lstm1(w_x)
  second_x,(second_h_x,second_c_x)=self.lstm2(first_x)
  output_x=self.linear(second_x)
  return output_x

将文本转换为数值形式

def trans_num(word2idx,text):
 text_list=[]
 for i in text:
  s=i.rstrip().replace('\r','').replace('\n','').split(' ')
  numtext=[word2idx[j] if j in word2idx.keys() else word2idx['_PAD'] for j in s ]
  text_list.append(numtext)
 return text_list

将Gensim里的词向量模型转为矩阵形式,后续导入到LSTM模型中

def establish_word2vec_matrix(model): #负责将数值索引转为要输入的数据
 word2idx = {"_PAD": 0} # 初始化 `[word : token]` 字典,后期 tokenize 语料库就是用该词典。
 num2idx = {0: "_PAD"}
 vocab_list = [(k, model.wv[k]) for k, v in model.wv.vocab.items()]

 # 存储所有 word2vec 中所有向量的数组,留意其中多一位,词向量全为 0, 用于 padding
 embeddings_matrix = np.zeros((len(model.wv.vocab.items()) + 1, model.vector_size))
 for i in range(len(vocab_list)):
  word = vocab_list[i][0]
  word2idx[word] = i + 1
  num2idx[i + 1] = word
  embeddings_matrix[i + 1] = vocab_list[i][1]
 embeddings_matrix = torch.Tensor(embeddings_matrix)
 return embeddings_matrix, word2idx, num2idx

训练过程

def train(model,epoch,learning_rate,batch_size,x, y, val_x, val_y):
 optimizor = optim.Adam(model.parameters(), lr=learning_rate)
 data = TensorDataset(x, y)
 data = DataLoader(data, batch_size=batch_size)
 for i in range(epoch):
  for j, (per_x, per_y) in enumerate(data):
   output_y = model(per_x)
   loss = F.cross_entropy(output_y.view(-1,output_y.size(2)), per_y.view(-1))
   optimizor.zero_grad()
   loss.backward()
   optimizor.step()
   arg_y=output_y.argmax(dim=2)
   fit_correct=(arg_y==per_y).sum()
   fit_acc=fit_correct.item()/(per_y.size(0)*per_y.size(1))
   print('##################################')
   print('第{}次迭代第{}批次的训练误差为{}'.format(i + 1, j + 1, loss), end=' ')
   print('第{}次迭代第{}批次的训练准确度为{}'.format(i + 1, j + 1, fit_acc))
   val_output_y = model(val_x)
   val_loss = F.cross_entropy(val_output_y.view(-1,val_output_y.size(2)), val_y.view(-1))
   arg_val_y=val_output_y.argmax(dim=2)
   val_correct=(arg_val_y==val_y).sum()
   val_acc=val_correct.item()/(val_y.size(0)*val_y.size(1))
   print('第{}次迭代第{}批次的预测误差为{}'.format(i + 1, j + 1, val_loss), end=' ')
   print('第{}次迭代第{}批次的预测准确度为{}'.format(i + 1, j + 1, val_acc))
 torch.save(model,'./extract_model.pkl')#保存模型

主函数部分

if __name__ =='__main__':
 #生成词向量矩阵
 word2vec = gensim.models.Word2Vec.load('./word2vec_model')
 embedding_matrix,word2idx,num2idx=establish_word2vec_matrix(word2vec)#输入的是词向量模型
 #
 train_data=pd.read_csv('./数据.csv')
 x=list(train_data['文本'])
 # 将文本从文字转化为数值,这部分trans_num函数你需要自己改动去适应你自己的数据集
 x=trans_num(word2idx,x)
 #x需要先进行填充,也就是每个句子都是一样长度,不够长度的以0来填充,填充词单独分为一类
 # #也就是说输入的x是固定长度的数值列表,例如[50,123,1850,21,199,0,0,...]
 #输入的y是[2,0,1,0,0,1,3,3,3,3,3,.....]
 #填充代码你自行编写,以下部分是针对我的数据集
 x=keras.preprocessing.sequence.pad_sequences(
   x,maxlen=60,value=0,padding='post',
 )
 y=list(train_data['BIO数值'])
 y_text=[]
 for i in y:
  s=i.rstrip().split(' ')
  numtext=[int(j) for j in s]
  y_text.append(numtext)
 y=y_text
 y=keras.preprocessing.sequence.pad_sequences(
   y,maxlen=60,value=3,padding='post',
  )
 # 将数据进行划分
 fit_x,val_x,fit_y,val_y=train_test_split(x,y,train_size=0.8,test_size=0.2)
 fit_x=torch.LongTensor(fit_x)
 fit_y=torch.LongTensor(fit_y)
 val_x=torch.LongTensor(val_x)
 val_y=torch.LongTensor(val_y)
 #开始应用
 w_extract=word_extract(d_model=200,embedding_matrix=embedding_matrix)
 train(model=w_extract,epoch=5,learning_rate=0.001,batch_size=50,
   x=fit_x,y=fit_y,val_x=val_x,val_y=val_y)#可以自行改动参数,设置学习率,批次,和迭代次数
 w_extract=torch.load('./extract_model.pkl')#加载保存好的模型
 pred_val_y=w_extract(val_x).argmax(dim=2)

以上这篇在pytorch中动态调整优化器的学习率方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中一些自然语言工具的使用的入门教程
Apr 13 Python
python正则表达式之作业计算器
Mar 18 Python
wxpython中自定义事件的实现与使用方法分析
Jul 21 Python
python获取外网IP并发邮件的实现方法
Oct 01 Python
Python+matplotlib实现填充螺旋实例
Jan 15 Python
Python paramiko模块的使用示例
Apr 11 Python
mac PyCharm添加Python解释器及添加package路径的方法
Oct 29 Python
OpenCV 边缘检测
Jul 10 Python
Python用input输入列表的实例代码
Feb 07 Python
Python读取xlsx数据生成图标代码实例
Aug 12 Python
python3 sqlite3限制条件查询的操作
Apr 07 Python
详解Python中__new__方法的作用
Mar 31 Python
CentOS 7如何实现定时执行python脚本
Jun 24 #Python
python tkiner实现 一个小小的图片翻页功能的示例代码
Jun 24 #Python
在tensorflow实现直接读取网络的参数(weight and bias)的值
Jun 24 #Python
基于pytorch中的Sequential用法说明
Jun 24 #Python
django haystack实现全文检索的示例代码
Jun 24 #Python
Python爬虫如何应对Cloudflare邮箱加密
Jun 24 #Python
python使用自定义钉钉机器人的示例代码
Jun 24 #Python
You might like
php使用fsockopen函数发送post,get请求获取网页内容的方法
2014/11/15 PHP
php实现paypal 授权登录
2015/05/28 PHP
Thinkphp5 如何隐藏入口文件index.php(URL重写)
2019/10/16 PHP
PHP数组基本用法与知识点总结
2020/06/02 PHP
IE6-8中Date不支持toISOString的修复方法
2014/05/04 Javascript
完美实现仿QQ空间评论回复特效
2015/05/06 Javascript
使用javaScript动态加载Js文件和Css文件
2015/10/24 Javascript
Vue.directive自定义指令的使用详解
2017/03/10 Javascript
jQuery实现页面倒计时并刷新效果
2017/03/13 Javascript
js 实现复选框只能选择一项的示例代码
2018/01/23 Javascript
JS删除数组里的某个元素方法
2018/02/03 Javascript
vue-router相关基础知识及工作原理
2018/03/16 Javascript
Mac下通过brew安装指定版本的nodejs教程
2018/05/17 NodeJs
vue项目部署到Apache服务器中遇到的问题解决
2018/08/24 Javascript
详解nvm管理多版本node踩坑
2019/07/26 Javascript
taro小程序添加骨架屏的实现代码
2019/11/15 Javascript
BootStrap前端框架使用方法详解
2020/02/26 Javascript
浅谈在vue-cli3项目中解决动态引入图片img404的问题
2020/08/04 Javascript
Python Trie树实现字典排序
2014/03/28 Python
Python3基础之list列表实例解析
2014/08/13 Python
在Python中进行自动化单元测试的教程
2015/04/15 Python
Django实现图片文字同时提交的方法
2015/05/26 Python
Python元组操作实例分析【创建、赋值、更新、删除等】
2017/07/24 Python
python+tkinter编写电脑桌面放大镜程序实例代码
2018/01/16 Python
python3获取两个日期之间所有日期,以及比较大小的实例
2018/04/08 Python
Omio美国:全欧洲低价大巴、火车和航班搜索和比价
2017/11/08 全球购物
MATCHESFASHION.COM法国官网:英国奢侈品零售商
2018/01/04 全球购物
Michael Kors加拿大官网:购买设计师手袋、手表、鞋子、服装等
2019/03/16 全球购物
请写出char *p与"零值"比较的if语句
2014/09/24 面试题
读书心得体会
2013/12/28 职场文书
教师师德反思材料
2014/02/15 职场文书
质量负责人任命书
2014/06/06 职场文书
艾滋病宣传标语
2014/06/25 职场文书
遗失证明范文
2015/06/19 职场文书
六个好看实用的 HTML + CSS 后台登录入口页面
2022/04/28 HTML / CSS
SpringBoot前端后端分离之Nginx服务器下载安装过程
2022/08/14 Servers