在pytorch中动态调整优化器的学习率方式


Posted in Python onJune 24, 2020

在深度学习中,经常需要动态调整学习率,以达到更好地训练效果,本文纪录在pytorch中的实现方法,其优化器实例为SGD优化器,其他如Adam优化器同样适用。

一般来说,在以SGD优化器作为基本优化器,然后根据epoch实现学习率指数下降,代码如下:

step = [10,20,30,40]
base_lr = 1e-4
sgd_opt = torch.optim.SGD(model.parameters(), lr=base_lr, nesterov=True, momentum=0.9)
def adjust_lr(epoch):
 lr = base_lr * (0.1 ** np.sum(epoch >= np.array(step)))
 for params_group in sgd_opt.param_groups:
  params_group['lr'] = lr
 return lr

只需要在每个train的epoch之前使用这个函数即可。

for epoch in range(60):
 model.train()
 adjust_lr(epoch)
 for ind, each in enumerate(train_loader):
 mat, label = each
 ...

补充知识:Pytorch框架下应用Bi-LSTM实现汽车评论文本关键词抽取

需要调用的模块及整体Bi-lstm流程

import torch
import pandas as pd
import numpy as np
from tensorflow import keras
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import gensim
from sklearn.model_selection import train_test_split
class word_extract(nn.Module):
 def __init__(self,d_model,embedding_matrix):
  super(word_extract, self).__init__()
  self.d_model=d_model
  self.embedding=nn.Embedding(num_embeddings=len(embedding_matrix),embedding_dim=200)
  self.embedding.weight.data.copy_(embedding_matrix)
  self.embedding.weight.requires_grad=False
  self.lstm1=nn.LSTM(input_size=200,hidden_size=50,bidirectional=True)
  self.lstm2=nn.LSTM(input_size=2*self.lstm1.hidden_size,hidden_size=50,bidirectional=True)
  self.linear=nn.Linear(2*self.lstm2.hidden_size,4)

 def forward(self,x):
  w_x=self.embedding(x)
  first_x,(first_h_x,first_c_x)=self.lstm1(w_x)
  second_x,(second_h_x,second_c_x)=self.lstm2(first_x)
  output_x=self.linear(second_x)
  return output_x

将文本转换为数值形式

def trans_num(word2idx,text):
 text_list=[]
 for i in text:
  s=i.rstrip().replace('\r','').replace('\n','').split(' ')
  numtext=[word2idx[j] if j in word2idx.keys() else word2idx['_PAD'] for j in s ]
  text_list.append(numtext)
 return text_list

将Gensim里的词向量模型转为矩阵形式,后续导入到LSTM模型中

def establish_word2vec_matrix(model): #负责将数值索引转为要输入的数据
 word2idx = {"_PAD": 0} # 初始化 `[word : token]` 字典,后期 tokenize 语料库就是用该词典。
 num2idx = {0: "_PAD"}
 vocab_list = [(k, model.wv[k]) for k, v in model.wv.vocab.items()]

 # 存储所有 word2vec 中所有向量的数组,留意其中多一位,词向量全为 0, 用于 padding
 embeddings_matrix = np.zeros((len(model.wv.vocab.items()) + 1, model.vector_size))
 for i in range(len(vocab_list)):
  word = vocab_list[i][0]
  word2idx[word] = i + 1
  num2idx[i + 1] = word
  embeddings_matrix[i + 1] = vocab_list[i][1]
 embeddings_matrix = torch.Tensor(embeddings_matrix)
 return embeddings_matrix, word2idx, num2idx

训练过程

def train(model,epoch,learning_rate,batch_size,x, y, val_x, val_y):
 optimizor = optim.Adam(model.parameters(), lr=learning_rate)
 data = TensorDataset(x, y)
 data = DataLoader(data, batch_size=batch_size)
 for i in range(epoch):
  for j, (per_x, per_y) in enumerate(data):
   output_y = model(per_x)
   loss = F.cross_entropy(output_y.view(-1,output_y.size(2)), per_y.view(-1))
   optimizor.zero_grad()
   loss.backward()
   optimizor.step()
   arg_y=output_y.argmax(dim=2)
   fit_correct=(arg_y==per_y).sum()
   fit_acc=fit_correct.item()/(per_y.size(0)*per_y.size(1))
   print('##################################')
   print('第{}次迭代第{}批次的训练误差为{}'.format(i + 1, j + 1, loss), end=' ')
   print('第{}次迭代第{}批次的训练准确度为{}'.format(i + 1, j + 1, fit_acc))
   val_output_y = model(val_x)
   val_loss = F.cross_entropy(val_output_y.view(-1,val_output_y.size(2)), val_y.view(-1))
   arg_val_y=val_output_y.argmax(dim=2)
   val_correct=(arg_val_y==val_y).sum()
   val_acc=val_correct.item()/(val_y.size(0)*val_y.size(1))
   print('第{}次迭代第{}批次的预测误差为{}'.format(i + 1, j + 1, val_loss), end=' ')
   print('第{}次迭代第{}批次的预测准确度为{}'.format(i + 1, j + 1, val_acc))
 torch.save(model,'./extract_model.pkl')#保存模型

主函数部分

if __name__ =='__main__':
 #生成词向量矩阵
 word2vec = gensim.models.Word2Vec.load('./word2vec_model')
 embedding_matrix,word2idx,num2idx=establish_word2vec_matrix(word2vec)#输入的是词向量模型
 #
 train_data=pd.read_csv('./数据.csv')
 x=list(train_data['文本'])
 # 将文本从文字转化为数值,这部分trans_num函数你需要自己改动去适应你自己的数据集
 x=trans_num(word2idx,x)
 #x需要先进行填充,也就是每个句子都是一样长度,不够长度的以0来填充,填充词单独分为一类
 # #也就是说输入的x是固定长度的数值列表,例如[50,123,1850,21,199,0,0,...]
 #输入的y是[2,0,1,0,0,1,3,3,3,3,3,.....]
 #填充代码你自行编写,以下部分是针对我的数据集
 x=keras.preprocessing.sequence.pad_sequences(
   x,maxlen=60,value=0,padding='post',
 )
 y=list(train_data['BIO数值'])
 y_text=[]
 for i in y:
  s=i.rstrip().split(' ')
  numtext=[int(j) for j in s]
  y_text.append(numtext)
 y=y_text
 y=keras.preprocessing.sequence.pad_sequences(
   y,maxlen=60,value=3,padding='post',
  )
 # 将数据进行划分
 fit_x,val_x,fit_y,val_y=train_test_split(x,y,train_size=0.8,test_size=0.2)
 fit_x=torch.LongTensor(fit_x)
 fit_y=torch.LongTensor(fit_y)
 val_x=torch.LongTensor(val_x)
 val_y=torch.LongTensor(val_y)
 #开始应用
 w_extract=word_extract(d_model=200,embedding_matrix=embedding_matrix)
 train(model=w_extract,epoch=5,learning_rate=0.001,batch_size=50,
   x=fit_x,y=fit_y,val_x=val_x,val_y=val_y)#可以自行改动参数,设置学习率,批次,和迭代次数
 w_extract=torch.load('./extract_model.pkl')#加载保存好的模型
 pred_val_y=w_extract(val_x).argmax(dim=2)

以上这篇在pytorch中动态调整优化器的学习率方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的__new__与__init__魔术方法理解笔记
Nov 08 Python
Python使用requests及BeautifulSoup构建爬虫实例代码
Jan 24 Python
Python cookbook(数据结构与算法)找出序列中出现次数最多的元素算法示例
Mar 15 Python
python如何爬取个性签名
Jun 19 Python
Django中数据库的数据关系:一对一,一对多,多对多
Oct 21 Python
详解Python3中ceil()函数用法
Feb 19 Python
sublime3之内网安装python插件Anaconda的流程
Nov 10 Python
利用python+request通过接口实现人员通行记录上传功能
Jan 13 Python
Python .py生成.pyd文件并打包.exe 的注意事项说明
Mar 04 Python
使用Python的开发框架Brownie部署以太坊智能合约
May 28 Python
pycharm部署django项目到云服务器的详细流程
Jun 29 Python
Python 详解通过Scrapy框架实现爬取CSDN全站热榜标题热词流程
Nov 11 Python
CentOS 7如何实现定时执行python脚本
Jun 24 #Python
python tkiner实现 一个小小的图片翻页功能的示例代码
Jun 24 #Python
在tensorflow实现直接读取网络的参数(weight and bias)的值
Jun 24 #Python
基于pytorch中的Sequential用法说明
Jun 24 #Python
django haystack实现全文检索的示例代码
Jun 24 #Python
Python爬虫如何应对Cloudflare邮箱加密
Jun 24 #Python
python使用自定义钉钉机器人的示例代码
Jun 24 #Python
You might like
php 短链接算法收集与分析
2011/12/30 PHP
PHP中error_log()函数的使用方法
2015/01/20 PHP
Laravel最佳分割路由文件(routes.php)的方式
2016/08/04 PHP
PHP中Notice错误常见解决方法
2017/04/28 PHP
Javascript注入技巧
2007/06/22 Javascript
javascript测试题练习代码
2012/10/10 Javascript
jQuery DOM插入节点操作指南
2015/03/03 Javascript
ionic在开发ios系统微信时键盘挡住输入框的解决方法(键盘弹出问题)
2016/09/06 Javascript
简单的渐变轮播插件
2017/01/12 Javascript
jQuery给表格添加分页效果
2017/03/02 Javascript
JS实现上传图片的三种方法并实现预览图片功能
2017/07/14 Javascript
vue实现前进刷新后退不刷新效果
2018/01/26 Javascript
vue-cli history模式实现tomcat部署报404的解决方式
2019/09/06 Javascript
python3.7.0的安装步骤
2018/08/27 Python
Python设计模式之工厂方法模式实例详解
2019/01/18 Python
解决pycharm的Python console不能调试当前程序的问题
2019/01/20 Python
利用python Selenium实现自动登陆京东签到领金币功能
2019/10/31 Python
Scrapy+Selenium自动获取cookie爬取网易云音乐个人喜爱歌单
2021/02/01 Python
viagogo法国票务平台:演唱会、体育比赛、戏剧门票
2017/03/27 全球购物
Bally美国官网:经典瑞士鞋履、手袋及配饰奢侈品牌
2018/05/18 全球购物
数以千计的折扣工业产品:ESE Direct
2018/05/20 全球购物
美国最大的电子宠物训练产品制造商:PetSafe
2018/10/12 全球购物
微软加拿大官方网站:Microsoft Canada
2019/04/28 全球购物
三星俄罗斯授权在线商店:Samsung俄罗斯
2019/09/28 全球购物
美国第一大药店连锁机构:Walgreens(沃尔格林)
2019/10/10 全球购物
JACK & JONES荷兰官网:男士服装和鞋子
2021/03/07 全球购物
abstract 可以和 virtual 一起使用吗?可以和 override 一起使用吗?
2012/10/15 面试题
高级文秘工作总结的自我评价
2013/09/28 职场文书
参赛口号
2014/06/16 职场文书
爱国口号
2014/06/19 职场文书
预备党员自我批评思想汇报
2014/10/10 职场文书
对外汉语专业大学生职业生涯规划书
2014/10/11 职场文书
2016小学新学期寄语
2015/12/04 职场文书
解决numpy数组互换两行及赋值的问题
2021/04/17 Python
一文搞懂python异常处理、模块与包
2021/06/26 Python
实现一个简单得数据响应系统
2021/11/11 Javascript