在pytorch中动态调整优化器的学习率方式


Posted in Python onJune 24, 2020

在深度学习中,经常需要动态调整学习率,以达到更好地训练效果,本文纪录在pytorch中的实现方法,其优化器实例为SGD优化器,其他如Adam优化器同样适用。

一般来说,在以SGD优化器作为基本优化器,然后根据epoch实现学习率指数下降,代码如下:

step = [10,20,30,40]
base_lr = 1e-4
sgd_opt = torch.optim.SGD(model.parameters(), lr=base_lr, nesterov=True, momentum=0.9)
def adjust_lr(epoch):
 lr = base_lr * (0.1 ** np.sum(epoch >= np.array(step)))
 for params_group in sgd_opt.param_groups:
  params_group['lr'] = lr
 return lr

只需要在每个train的epoch之前使用这个函数即可。

for epoch in range(60):
 model.train()
 adjust_lr(epoch)
 for ind, each in enumerate(train_loader):
 mat, label = each
 ...

补充知识:Pytorch框架下应用Bi-LSTM实现汽车评论文本关键词抽取

需要调用的模块及整体Bi-lstm流程

import torch
import pandas as pd
import numpy as np
from tensorflow import keras
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import gensim
from sklearn.model_selection import train_test_split
class word_extract(nn.Module):
 def __init__(self,d_model,embedding_matrix):
  super(word_extract, self).__init__()
  self.d_model=d_model
  self.embedding=nn.Embedding(num_embeddings=len(embedding_matrix),embedding_dim=200)
  self.embedding.weight.data.copy_(embedding_matrix)
  self.embedding.weight.requires_grad=False
  self.lstm1=nn.LSTM(input_size=200,hidden_size=50,bidirectional=True)
  self.lstm2=nn.LSTM(input_size=2*self.lstm1.hidden_size,hidden_size=50,bidirectional=True)
  self.linear=nn.Linear(2*self.lstm2.hidden_size,4)

 def forward(self,x):
  w_x=self.embedding(x)
  first_x,(first_h_x,first_c_x)=self.lstm1(w_x)
  second_x,(second_h_x,second_c_x)=self.lstm2(first_x)
  output_x=self.linear(second_x)
  return output_x

将文本转换为数值形式

def trans_num(word2idx,text):
 text_list=[]
 for i in text:
  s=i.rstrip().replace('\r','').replace('\n','').split(' ')
  numtext=[word2idx[j] if j in word2idx.keys() else word2idx['_PAD'] for j in s ]
  text_list.append(numtext)
 return text_list

将Gensim里的词向量模型转为矩阵形式,后续导入到LSTM模型中

def establish_word2vec_matrix(model): #负责将数值索引转为要输入的数据
 word2idx = {"_PAD": 0} # 初始化 `[word : token]` 字典,后期 tokenize 语料库就是用该词典。
 num2idx = {0: "_PAD"}
 vocab_list = [(k, model.wv[k]) for k, v in model.wv.vocab.items()]

 # 存储所有 word2vec 中所有向量的数组,留意其中多一位,词向量全为 0, 用于 padding
 embeddings_matrix = np.zeros((len(model.wv.vocab.items()) + 1, model.vector_size))
 for i in range(len(vocab_list)):
  word = vocab_list[i][0]
  word2idx[word] = i + 1
  num2idx[i + 1] = word
  embeddings_matrix[i + 1] = vocab_list[i][1]
 embeddings_matrix = torch.Tensor(embeddings_matrix)
 return embeddings_matrix, word2idx, num2idx

训练过程

def train(model,epoch,learning_rate,batch_size,x, y, val_x, val_y):
 optimizor = optim.Adam(model.parameters(), lr=learning_rate)
 data = TensorDataset(x, y)
 data = DataLoader(data, batch_size=batch_size)
 for i in range(epoch):
  for j, (per_x, per_y) in enumerate(data):
   output_y = model(per_x)
   loss = F.cross_entropy(output_y.view(-1,output_y.size(2)), per_y.view(-1))
   optimizor.zero_grad()
   loss.backward()
   optimizor.step()
   arg_y=output_y.argmax(dim=2)
   fit_correct=(arg_y==per_y).sum()
   fit_acc=fit_correct.item()/(per_y.size(0)*per_y.size(1))
   print('##################################')
   print('第{}次迭代第{}批次的训练误差为{}'.format(i + 1, j + 1, loss), end=' ')
   print('第{}次迭代第{}批次的训练准确度为{}'.format(i + 1, j + 1, fit_acc))
   val_output_y = model(val_x)
   val_loss = F.cross_entropy(val_output_y.view(-1,val_output_y.size(2)), val_y.view(-1))
   arg_val_y=val_output_y.argmax(dim=2)
   val_correct=(arg_val_y==val_y).sum()
   val_acc=val_correct.item()/(val_y.size(0)*val_y.size(1))
   print('第{}次迭代第{}批次的预测误差为{}'.format(i + 1, j + 1, val_loss), end=' ')
   print('第{}次迭代第{}批次的预测准确度为{}'.format(i + 1, j + 1, val_acc))
 torch.save(model,'./extract_model.pkl')#保存模型

主函数部分

if __name__ =='__main__':
 #生成词向量矩阵
 word2vec = gensim.models.Word2Vec.load('./word2vec_model')
 embedding_matrix,word2idx,num2idx=establish_word2vec_matrix(word2vec)#输入的是词向量模型
 #
 train_data=pd.read_csv('./数据.csv')
 x=list(train_data['文本'])
 # 将文本从文字转化为数值,这部分trans_num函数你需要自己改动去适应你自己的数据集
 x=trans_num(word2idx,x)
 #x需要先进行填充,也就是每个句子都是一样长度,不够长度的以0来填充,填充词单独分为一类
 # #也就是说输入的x是固定长度的数值列表,例如[50,123,1850,21,199,0,0,...]
 #输入的y是[2,0,1,0,0,1,3,3,3,3,3,.....]
 #填充代码你自行编写,以下部分是针对我的数据集
 x=keras.preprocessing.sequence.pad_sequences(
   x,maxlen=60,value=0,padding='post',
 )
 y=list(train_data['BIO数值'])
 y_text=[]
 for i in y:
  s=i.rstrip().split(' ')
  numtext=[int(j) for j in s]
  y_text.append(numtext)
 y=y_text
 y=keras.preprocessing.sequence.pad_sequences(
   y,maxlen=60,value=3,padding='post',
  )
 # 将数据进行划分
 fit_x,val_x,fit_y,val_y=train_test_split(x,y,train_size=0.8,test_size=0.2)
 fit_x=torch.LongTensor(fit_x)
 fit_y=torch.LongTensor(fit_y)
 val_x=torch.LongTensor(val_x)
 val_y=torch.LongTensor(val_y)
 #开始应用
 w_extract=word_extract(d_model=200,embedding_matrix=embedding_matrix)
 train(model=w_extract,epoch=5,learning_rate=0.001,batch_size=50,
   x=fit_x,y=fit_y,val_x=val_x,val_y=val_y)#可以自行改动参数,设置学习率,批次,和迭代次数
 w_extract=torch.load('./extract_model.pkl')#加载保存好的模型
 pred_val_y=w_extract(val_x).argmax(dim=2)

以上这篇在pytorch中动态调整优化器的学习率方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
写了个监控nginx进程的Python脚本
May 10 Python
python格式化字符串实例总结
Sep 28 Python
Python复制目录结构脚本代码分享
Mar 06 Python
Python切片知识解析
Mar 06 Python
Python自动发邮件脚本
Mar 31 Python
python记录程序运行时间的三种方法
Jul 14 Python
python 日期操作类代码
May 05 Python
python requests 测试代理ip是否生效
Jul 25 Python
利用python实现凯撒密码加解密功能
Mar 31 Python
Python flask框架实现浏览器点击自定义跳转页面
Jun 04 Python
解决python运行效率不高的问题
Jul 20 Python
使用Python获取字典键对应值的方法
Apr 26 Python
CentOS 7如何实现定时执行python脚本
Jun 24 #Python
python tkiner实现 一个小小的图片翻页功能的示例代码
Jun 24 #Python
在tensorflow实现直接读取网络的参数(weight and bias)的值
Jun 24 #Python
基于pytorch中的Sequential用法说明
Jun 24 #Python
django haystack实现全文检索的示例代码
Jun 24 #Python
Python爬虫如何应对Cloudflare邮箱加密
Jun 24 #Python
python使用自定义钉钉机器人的示例代码
Jun 24 #Python
You might like
咖啡豆的最常见发酵处理方法,详细了解一下
2021/03/03 冲泡冲煮
PHP读取RSS(Feed)简单实例
2014/06/12 PHP
CI框架自动加载session出现报错的解决办法
2014/06/17 PHP
php生成mysql的数据字典
2016/07/07 PHP
ThinkPHP简单使用memcache缓存的方法
2016/11/15 PHP
PHP函数rtrim()使用中的怪异现象分析
2017/02/24 PHP
PHP中的函数声明与使用详解
2017/05/27 PHP
一个很简单的jquery+xml+ajax的无刷新树结构(无css,后台是c#)
2010/06/02 Javascript
js监听键盘事件示例代码
2013/07/26 Javascript
javascript避免数字计算精度误差的方法详解
2014/03/05 Javascript
jquery easyui 对于开始时间小于结束时间的判断示例
2014/03/22 Javascript
JavaScript将一个数组插入到另一个数组的方法
2015/03/19 Javascript
表单验证正则表达式实例代码详解
2015/11/09 Javascript
图解prototype、proto和constructor的三角关系
2016/07/31 Javascript
Vue精简版风格指南(推荐)
2018/01/30 Javascript
微信小程序进入广告实现代码实例
2019/09/19 Javascript
原生javascript制作贪吃蛇小游戏的方法分析
2020/02/26 Javascript
JavaScript事件委托实现原理及优点进行
2020/08/29 Javascript
使用python统计文件行数示例分享
2014/02/21 Python
Python函数参数类型*、**的区别
2015/04/11 Python
Python2.7基于淘宝接口获取IP地址所在地理位置的方法【测试可用】
2017/06/07 Python
python检测IP地址变化并触发事件
2018/12/26 Python
Python闭包思想与用法浅析
2018/12/27 Python
分享一个pycharm专业版安装的永久使用方法
2019/09/24 Python
Python进程Multiprocessing模块原理解析
2020/02/28 Python
10个python3常用排序算法详细说明与实例(快速排序,冒泡排序,桶排序,基数排序,堆排序,希尔排序,归并排序,计数排序)
2020/03/17 Python
Michael Kors澳大利亚官网:世界知名的奢侈饰品和成衣设计师
2020/02/13 全球购物
PPP协议组成及简述协议协商的基本过程
2015/05/28 面试题
卫生巾广告词
2014/03/18 职场文书
企业文化建设实施方案
2014/03/22 职场文书
运动会方队口号
2014/06/07 职场文书
美容院合作经营协议书
2014/10/10 职场文书
社区服务活动感想
2015/08/11 职场文书
MySQL通过binlog恢复数据
2021/05/27 MySQL
简单总结SpringMVC拦截器的使用方法
2021/06/28 Java/Android
前端框架ECharts dataset对数据可视化的高级管理
2022/12/24 Javascript