PyTorch预训练Bert模型的示例


Posted in Python onNovember 17, 2020

本文介绍以下内容:
1. 使用transformers框架做预训练的bert-base模型;
2. 开发平台使用Google的Colab平台,白嫖GPU加速;
3. 使用datasets模块下载IMDB影评数据作为训练数据。

transformers模块简介

transformers框架为Huggingface开源的深度学习框架,支持几乎所有的Transformer架构的预训练模型。使用非常的方便,本文基于此框架,尝试一下预训练模型的使用,简单易用。

本来打算预训练bert-large模型,发现colab上GPU显存不够用,只能使用base版本了。打开colab,并且设置好GPU加速,接下来开始介绍代码。

代码实现

首先安装数据下载模块和transformers包。

pip install datasets
pip install transformers

使用datasets下载IMDB数据,返回DatasetDict类型的数据.返回的数据是文本类型,需要进行编码。下面会使用tokenizer进行编码。

from datasets import load_dataset

imdb = load_dataset('imdb')
print(imdb['train'][:3]) # 打印前3条训练数据

接下来加载tokenizer和模型.从transformers导入AutoModelForSequenceClassification, AutoTokenizer,创建模型和tokenizer。

from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_checkpoint = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)

对原始数据进行编码,并且分批次(batch)

def preprocessing_func(examples):
  return tokenizer(examples['text'], 
           padding=True,
           truncation=True, max_length=300)

batch_size = 16

encoded_data = imdb.map(preprocessing_func, batched=True, batch_size=batch_size)

上面得到编码数据,每个批次设置为16.接下来需要指定训练的参数,训练参数的指定使用transformers给出的接口类TrainingArguments,模型的训练可以使用Trainer。

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
  'out',
  per_device_train_batch_size=batch_size,
  per_device_eval_batch_size=batch_size,
  learning_rate=5e-5,
  evaluation_strategy='epoch',
  num_train_epochs=10,
  load_best_model_at_end=True,
)

trainer = Trainer(
  model,
  args=args,
  train_dataset=encoded_data['train'],
  eval_dataset=encoded_data['test'],
  tokenizer=tokenizer
)

训练模型使用trainer对象的train方法

trainer.train()

PyTorch预训练Bert模型的示例

评估模型使用trainer对象的evaluate方法

trainer.evaluate()

总结

本文介绍了基于transformers框架实现的bert预训练模型,此框架提供了非常友好的接口,可以方便读者尝试各种预训练模型。同时datasets也提供了很多数据集,便于学习NLP的各种问题。加上Google提供的colab环境,数据下载和预训练模型下载都非常快,建议读者自行去炼丹。本文完整的案例下载

以上就是PyTorch预训练Bert模型的示例的详细内容,更多关于PyTorch预训练Bert模型的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python3基础之函数用法
Aug 13 Python
python绘制双柱形图代码实例
Dec 14 Python
python爬虫爬取淘宝商品信息(selenum+phontomjs)
Feb 24 Python
Pycharm 设置自定义背景颜色的图文教程
May 23 Python
Pandas读写CSV文件的方法示例
Mar 27 Python
python使用pandas抽样训练数据中某个类别实例
Feb 28 Python
利用python画出AUC曲线的实例
Feb 28 Python
150行Python代码实现带界面的数独游戏
Apr 04 Python
Python Django搭建网站流程图解
Jun 13 Python
Python函数递归调用实现原理实例解析
Aug 11 Python
python em算法的实现
Oct 03 Python
基于Python爬取京东双十一商品价格曲线
Oct 23 Python
python 下载文件的多种方法汇总
Nov 17 #Python
python跨文件使用全局变量的实现
Nov 17 #Python
Python中logging日志的四个等级和使用
Nov 17 #Python
Python爬虫破解登陆哔哩哔哩的方法
Nov 17 #Python
appium+python自动化配置(adk、jdk、node.js)
Nov 17 #Python
python调用百度API实现人脸识别
Nov 17 #Python
详解利用python识别图片中的条码(pyzbar)及条码图片矫正和增强
Nov 17 #Python
You might like
SWFUpload与CI不能正确上传识别文件MIME类型解决方法分享
2011/04/18 PHP
PHP运行模式的深入理解
2013/06/03 PHP
php创建基本身份认证站点的方法详解
2013/06/08 PHP
php header函数的常用http头设置
2015/06/25 PHP
php版微信自定义回复功能示例
2016/12/05 PHP
用javascript实现的仿Flash广告图片轮换效果
2007/04/24 Javascript
JavaScript 小型打飞机游戏实现原理说明
2010/10/28 Javascript
jQuery 选择器、DOM操作、事件、动画
2010/11/25 Javascript
javascript模版引擎-tmpl的bug修复与性能优化分析
2011/10/23 Javascript
httpclient模拟登陆具体实现(使用js设置cookie)
2013/12/11 Javascript
setTimeout内不支持jquery的选择器的解决方案
2015/04/28 Javascript
jQuery同步提交示例代码
2015/12/12 Javascript
AngularJS使用ngOption实现下拉列表的实例代码
2016/01/23 Javascript
手机端图片缩放旋转全屏查看PhotoSwipe.js插件实现
2016/08/25 Javascript
Sequelize中用group by进行分组聚合查询
2016/12/12 Javascript
简单实现JS上传图片预览功能
2017/04/14 Javascript
vue插件vue-resource的使用笔记(小结)
2017/08/04 Javascript
Django使用多数据库的方法
2017/09/06 Javascript
详解JS中统计函数执行次数与执行时间
2018/09/04 Javascript
如何在vue里面优雅的解决跨域(路由冲突问题)
2019/01/20 Javascript
Vue的v-model的几种修饰符.lazy,.number和.trim的用法说明
2020/08/05 Javascript
python发送HTTP请求的方法小结
2015/07/08 Python
Django框架中方法的访问和查找
2015/07/15 Python
使用Anaconda3建立虚拟独立的python2.7环境方法
2018/06/11 Python
Python设计模式之代理模式实例详解
2019/01/19 Python
python爬虫 Pyppeteer使用方法解析
2019/09/28 Python
python 单线程和异步协程工作方式解析
2019/09/28 Python
在tensorflow中设置使用某一块GPU、多GPU、CPU的操作
2020/02/07 Python
解决Django响应JsonResponse返回json格式数据报错问题
2020/08/09 Python
英国Zoro工具:手动工具,电动工具和个人防护用品
2016/11/02 全球购物
哈萨克斯坦最大的时装、鞋子和配饰在线商店:Lamoda.kz
2019/11/19 全球购物
使用C#编写创建一个线程的代码
2013/01/22 面试题
品恩科技软件测试面试题
2014/10/26 面试题
酒店前台辞职书
2015/02/26 职场文书
预备党员的思想汇报,你真的会写吗?
2019/06/28 职场文书
CSS 伪元素::marker详解
2021/06/26 HTML / CSS