在pytorch中动态调整优化器的学习率方式


Posted in Python onJune 24, 2020

在深度学习中,经常需要动态调整学习率,以达到更好地训练效果,本文纪录在pytorch中的实现方法,其优化器实例为SGD优化器,其他如Adam优化器同样适用。

一般来说,在以SGD优化器作为基本优化器,然后根据epoch实现学习率指数下降,代码如下:

step = [10,20,30,40]
base_lr = 1e-4
sgd_opt = torch.optim.SGD(model.parameters(), lr=base_lr, nesterov=True, momentum=0.9)
def adjust_lr(epoch):
 lr = base_lr * (0.1 ** np.sum(epoch >= np.array(step)))
 for params_group in sgd_opt.param_groups:
  params_group['lr'] = lr
 return lr

只需要在每个train的epoch之前使用这个函数即可。

for epoch in range(60):
 model.train()
 adjust_lr(epoch)
 for ind, each in enumerate(train_loader):
 mat, label = each
 ...

补充知识:Pytorch框架下应用Bi-LSTM实现汽车评论文本关键词抽取

需要调用的模块及整体Bi-lstm流程

import torch
import pandas as pd
import numpy as np
from tensorflow import keras
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import gensim
from sklearn.model_selection import train_test_split
class word_extract(nn.Module):
 def __init__(self,d_model,embedding_matrix):
  super(word_extract, self).__init__()
  self.d_model=d_model
  self.embedding=nn.Embedding(num_embeddings=len(embedding_matrix),embedding_dim=200)
  self.embedding.weight.data.copy_(embedding_matrix)
  self.embedding.weight.requires_grad=False
  self.lstm1=nn.LSTM(input_size=200,hidden_size=50,bidirectional=True)
  self.lstm2=nn.LSTM(input_size=2*self.lstm1.hidden_size,hidden_size=50,bidirectional=True)
  self.linear=nn.Linear(2*self.lstm2.hidden_size,4)

 def forward(self,x):
  w_x=self.embedding(x)
  first_x,(first_h_x,first_c_x)=self.lstm1(w_x)
  second_x,(second_h_x,second_c_x)=self.lstm2(first_x)
  output_x=self.linear(second_x)
  return output_x

将文本转换为数值形式

def trans_num(word2idx,text):
 text_list=[]
 for i in text:
  s=i.rstrip().replace('\r','').replace('\n','').split(' ')
  numtext=[word2idx[j] if j in word2idx.keys() else word2idx['_PAD'] for j in s ]
  text_list.append(numtext)
 return text_list

将Gensim里的词向量模型转为矩阵形式,后续导入到LSTM模型中

def establish_word2vec_matrix(model): #负责将数值索引转为要输入的数据
 word2idx = {"_PAD": 0} # 初始化 `[word : token]` 字典,后期 tokenize 语料库就是用该词典。
 num2idx = {0: "_PAD"}
 vocab_list = [(k, model.wv[k]) for k, v in model.wv.vocab.items()]

 # 存储所有 word2vec 中所有向量的数组,留意其中多一位,词向量全为 0, 用于 padding
 embeddings_matrix = np.zeros((len(model.wv.vocab.items()) + 1, model.vector_size))
 for i in range(len(vocab_list)):
  word = vocab_list[i][0]
  word2idx[word] = i + 1
  num2idx[i + 1] = word
  embeddings_matrix[i + 1] = vocab_list[i][1]
 embeddings_matrix = torch.Tensor(embeddings_matrix)
 return embeddings_matrix, word2idx, num2idx

训练过程

def train(model,epoch,learning_rate,batch_size,x, y, val_x, val_y):
 optimizor = optim.Adam(model.parameters(), lr=learning_rate)
 data = TensorDataset(x, y)
 data = DataLoader(data, batch_size=batch_size)
 for i in range(epoch):
  for j, (per_x, per_y) in enumerate(data):
   output_y = model(per_x)
   loss = F.cross_entropy(output_y.view(-1,output_y.size(2)), per_y.view(-1))
   optimizor.zero_grad()
   loss.backward()
   optimizor.step()
   arg_y=output_y.argmax(dim=2)
   fit_correct=(arg_y==per_y).sum()
   fit_acc=fit_correct.item()/(per_y.size(0)*per_y.size(1))
   print('##################################')
   print('第{}次迭代第{}批次的训练误差为{}'.format(i + 1, j + 1, loss), end=' ')
   print('第{}次迭代第{}批次的训练准确度为{}'.format(i + 1, j + 1, fit_acc))
   val_output_y = model(val_x)
   val_loss = F.cross_entropy(val_output_y.view(-1,val_output_y.size(2)), val_y.view(-1))
   arg_val_y=val_output_y.argmax(dim=2)
   val_correct=(arg_val_y==val_y).sum()
   val_acc=val_correct.item()/(val_y.size(0)*val_y.size(1))
   print('第{}次迭代第{}批次的预测误差为{}'.format(i + 1, j + 1, val_loss), end=' ')
   print('第{}次迭代第{}批次的预测准确度为{}'.format(i + 1, j + 1, val_acc))
 torch.save(model,'./extract_model.pkl')#保存模型

主函数部分

if __name__ =='__main__':
 #生成词向量矩阵
 word2vec = gensim.models.Word2Vec.load('./word2vec_model')
 embedding_matrix,word2idx,num2idx=establish_word2vec_matrix(word2vec)#输入的是词向量模型
 #
 train_data=pd.read_csv('./数据.csv')
 x=list(train_data['文本'])
 # 将文本从文字转化为数值,这部分trans_num函数你需要自己改动去适应你自己的数据集
 x=trans_num(word2idx,x)
 #x需要先进行填充,也就是每个句子都是一样长度,不够长度的以0来填充,填充词单独分为一类
 # #也就是说输入的x是固定长度的数值列表,例如[50,123,1850,21,199,0,0,...]
 #输入的y是[2,0,1,0,0,1,3,3,3,3,3,.....]
 #填充代码你自行编写,以下部分是针对我的数据集
 x=keras.preprocessing.sequence.pad_sequences(
   x,maxlen=60,value=0,padding='post',
 )
 y=list(train_data['BIO数值'])
 y_text=[]
 for i in y:
  s=i.rstrip().split(' ')
  numtext=[int(j) for j in s]
  y_text.append(numtext)
 y=y_text
 y=keras.preprocessing.sequence.pad_sequences(
   y,maxlen=60,value=3,padding='post',
  )
 # 将数据进行划分
 fit_x,val_x,fit_y,val_y=train_test_split(x,y,train_size=0.8,test_size=0.2)
 fit_x=torch.LongTensor(fit_x)
 fit_y=torch.LongTensor(fit_y)
 val_x=torch.LongTensor(val_x)
 val_y=torch.LongTensor(val_y)
 #开始应用
 w_extract=word_extract(d_model=200,embedding_matrix=embedding_matrix)
 train(model=w_extract,epoch=5,learning_rate=0.001,batch_size=50,
   x=fit_x,y=fit_y,val_x=val_x,val_y=val_y)#可以自行改动参数,设置学习率,批次,和迭代次数
 w_extract=torch.load('./extract_model.pkl')#加载保存好的模型
 pred_val_y=w_extract(val_x).argmax(dim=2)

以上这篇在pytorch中动态调整优化器的学习率方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用multiprocessing模块实现带回调函数的异步调用方法
Apr 18 Python
在Python中使用mechanize模块模拟浏览器功能
May 05 Python
python实现手机通讯录搜索功能
Feb 22 Python
使用Python微信库itchat获得好友和群组已撤回的消息
Jun 24 Python
Python3实现的判断环形链表算法示例
Mar 07 Python
Python3.5内置模块之shelve模块、xml模块、configparser模块、hashlib、hmac模块用法分析
Apr 27 Python
django框架模板中定义变量(set variable in django template)的方法分析
Jun 24 Python
Python 读取串口数据,动态绘图的示例
Jul 02 Python
解决python web项目意外关闭,但占用端口的问题
Dec 17 Python
关于tf.nn.dynamic_rnn返回值详解
Jan 20 Python
python线性插值解析
Jul 05 Python
python 实现Requests发送带cookies的请求
Feb 08 Python
CentOS 7如何实现定时执行python脚本
Jun 24 #Python
python tkiner实现 一个小小的图片翻页功能的示例代码
Jun 24 #Python
在tensorflow实现直接读取网络的参数(weight and bias)的值
Jun 24 #Python
基于pytorch中的Sequential用法说明
Jun 24 #Python
django haystack实现全文检索的示例代码
Jun 24 #Python
Python爬虫如何应对Cloudflare邮箱加密
Jun 24 #Python
python使用自定义钉钉机器人的示例代码
Jun 24 #Python
You might like
神族 PROTOSS 概述
2020/03/14 星际争霸
用PHP写的一个冒泡排序法的函数简单实例
2016/05/26 PHP
JavaScript 异步调用框架 (Part 6 - 实例 & 模式)
2009/08/04 Javascript
javascript 从if else 到 switch case 再到抽象
2010/07/17 Javascript
前端轻量级MVC框架CanJS详解
2014/09/26 Javascript
jquery插件jquery.nicescroll实现图片无滚动条左右拖拽的方法
2015/08/10 Javascript
七个不允许错过的jQuery小技巧
2015/12/21 Javascript
javascript实现checkbox复选框实例代码
2016/01/10 Javascript
Node.js中JavaScript操作MySQL的常用方法整理
2016/03/01 Javascript
Bootstrap页面布局基础知识全面解析
2016/06/13 Javascript
AngularJS 实现JavaScript 动画效果详解
2016/09/08 Javascript
jQuery弹出窗口简单实现代码
2017/03/09 Javascript
JS对象深度克隆实例分析
2017/03/16 Javascript
jquery将标签元素的高设为屏幕的百分比
2017/04/19 jQuery
深入理解Commonjs规范及Node模块实现
2017/05/17 Javascript
微信小程序 五星评分的实现实例
2017/08/04 Javascript
Node.js搭建小程序后台服务
2018/01/03 Javascript
在Vue组件中使用 TypeScript的方法
2018/02/28 Javascript
原生js代码能实现call和bind吗
2019/07/31 Javascript
JavaScript的console命令使用实例
2019/12/03 Javascript
关于ES6尾调用优化的使用
2020/09/11 Javascript
Python函数中的函数(闭包)用法实例
2016/03/15 Python
Python3操作SQL Server数据库(实例讲解)
2017/10/21 Python
python爬虫之urllib3的使用示例
2018/07/09 Python
对python opencv 添加文字 cv2.putText 的各参数介绍
2018/12/05 Python
python2.7 安装pip的方法步骤(管用)
2019/05/05 Python
Python3日期与时间戳转换的几种方法详解
2019/06/04 Python
python根据多个文件名批量查找文件
2019/08/13 Python
python计算无向图节点度的实例代码
2019/11/22 Python
python except异常处理之后不退出,解决异常继续执行的实现
2020/04/25 Python
python使用nibabel和sitk读取保存nii.gz文件实例
2020/07/01 Python
澳大利亚家具和家居用品在线商店:Interiors Online
2018/03/05 全球购物
美国酒店控股公司:Choice Hotels
2018/06/15 全球购物
Marlies Dekkers内衣法国官方网上商店:国际知名的荷兰内衣品牌
2019/03/18 全球购物
优秀共青团员事迹材料
2014/12/25 职场文书
详解Redis集群搭建的三种方式
2021/05/31 Redis