PyTorch预训练Bert模型的示例


Posted in Python onNovember 17, 2020

本文介绍以下内容:
1. 使用transformers框架做预训练的bert-base模型;
2. 开发平台使用Google的Colab平台,白嫖GPU加速;
3. 使用datasets模块下载IMDB影评数据作为训练数据。

transformers模块简介

transformers框架为Huggingface开源的深度学习框架,支持几乎所有的Transformer架构的预训练模型。使用非常的方便,本文基于此框架,尝试一下预训练模型的使用,简单易用。

本来打算预训练bert-large模型,发现colab上GPU显存不够用,只能使用base版本了。打开colab,并且设置好GPU加速,接下来开始介绍代码。

代码实现

首先安装数据下载模块和transformers包。

pip install datasets
pip install transformers

使用datasets下载IMDB数据,返回DatasetDict类型的数据.返回的数据是文本类型,需要进行编码。下面会使用tokenizer进行编码。

from datasets import load_dataset

imdb = load_dataset('imdb')
print(imdb['train'][:3]) # 打印前3条训练数据

接下来加载tokenizer和模型.从transformers导入AutoModelForSequenceClassification, AutoTokenizer,创建模型和tokenizer。

from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_checkpoint = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)

对原始数据进行编码,并且分批次(batch)

def preprocessing_func(examples):
  return tokenizer(examples['text'], 
           padding=True,
           truncation=True, max_length=300)

batch_size = 16

encoded_data = imdb.map(preprocessing_func, batched=True, batch_size=batch_size)

上面得到编码数据,每个批次设置为16.接下来需要指定训练的参数,训练参数的指定使用transformers给出的接口类TrainingArguments,模型的训练可以使用Trainer。

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
  'out',
  per_device_train_batch_size=batch_size,
  per_device_eval_batch_size=batch_size,
  learning_rate=5e-5,
  evaluation_strategy='epoch',
  num_train_epochs=10,
  load_best_model_at_end=True,
)

trainer = Trainer(
  model,
  args=args,
  train_dataset=encoded_data['train'],
  eval_dataset=encoded_data['test'],
  tokenizer=tokenizer
)

训练模型使用trainer对象的train方法

trainer.train()

PyTorch预训练Bert模型的示例

评估模型使用trainer对象的evaluate方法

trainer.evaluate()

总结

本文介绍了基于transformers框架实现的bert预训练模型,此框架提供了非常友好的接口,可以方便读者尝试各种预训练模型。同时datasets也提供了很多数据集,便于学习NLP的各种问题。加上Google提供的colab环境,数据下载和预训练模型下载都非常快,建议读者自行去炼丹。本文完整的案例下载

以上就是PyTorch预训练Bert模型的示例的详细内容,更多关于PyTorch预训练Bert模型的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python线程的两种编程方式
Apr 14 Python
python操作ie登陆土豆网的方法
May 09 Python
学习python之编写简单乘法口诀表实现代码
Feb 27 Python
pandas获取groupby分组里最大值所在的行方法
Apr 20 Python
基于python实现自动化办公学习笔记(CSV、word、Excel、PPT)
Aug 06 Python
Python爬取破解无线网络wifi密码过程解析
Sep 17 Python
python进程的状态、创建及使用方法详解
Dec 06 Python
Python Scrapy框架第一个入门程序示例
Feb 05 Python
Python文件操作模拟用户登陆代码实例
Jun 09 Python
Python+Opencv身份证号码区域提取及识别实现
Aug 25 Python
Python爬虫爬取有道实现翻译功能
Nov 27 Python
python 基于opencv实现图像增强
Dec 23 Python
python 下载文件的多种方法汇总
Nov 17 #Python
python跨文件使用全局变量的实现
Nov 17 #Python
Python中logging日志的四个等级和使用
Nov 17 #Python
Python爬虫破解登陆哔哩哔哩的方法
Nov 17 #Python
appium+python自动化配置(adk、jdk、node.js)
Nov 17 #Python
python调用百度API实现人脸识别
Nov 17 #Python
详解利用python识别图片中的条码(pyzbar)及条码图片矫正和增强
Nov 17 #Python
You might like
第三节 定义一个类 [3]
2006/10/09 PHP
PHP输出控制功能在简繁体转换中的应用
2006/10/09 PHP
将OICQ数据转成MYSQL数据
2006/10/09 PHP
PHP中去除换行解决办法小结(PHP_EOL)
2011/11/27 PHP
php全角字符转换为半角函数
2014/02/07 PHP
PHP中SESSION的注销与清除
2015/04/16 PHP
php面向对象重点知识分享
2019/09/27 PHP
JS 实现完美include载入实现代码
2010/08/05 Javascript
ko knockoutjs动态属性绑定技巧应用
2012/11/14 Javascript
url参数中有+、空格、=、%、&、#等特殊符号的问题解决
2013/05/15 Javascript
关于IE中getElementsByClassName不能用的问题解决方法
2013/08/26 Javascript
浏览器兼容性问题大汇总
2015/12/17 Javascript
正则验证小数点后面只能有两位数的方法
2017/02/28 Javascript
Vue上传组件vue Simple Uploader的用法示例
2017/08/25 Javascript
vue组件父子间通信详解(三)
2017/11/07 Javascript
Vue官方文档梳理之全局配置
2017/11/22 Javascript
详解JS函数stack size计算方法
2018/06/18 Javascript
详解vue中移动端自适应方案
2019/05/05 Javascript
[47:52]DOTA2-DPC中国联赛正赛 iG vs LBZS BO3 第二场 3月4日
2021/03/11 DOTA
python基础教程之基本内置数据类型介绍
2014/02/20 Python
简单介绍Ruby中的CGI编程
2015/04/10 Python
Python中列表的一些基本操作知识汇总
2015/05/20 Python
Python实现删除文件但保留指定文件
2015/06/21 Python
Python彩色化Linux的命令行终端界面的代码实例分享
2016/07/02 Python
树莓派使用USB摄像头和motion实现监控
2019/06/22 Python
简单了解python的内存管理机制
2019/07/08 Python
Django自定义用户表+自定义admin后台中的字段实例
2019/11/18 Python
pytorch:实现简单的GAN示例(MNIST数据集)
2020/01/10 Python
解决PyCharm不在run输出运行结果而不是再Console里输出的问题
2020/09/21 Python
在HTML5 canvas里用卷积核进行图像处理的方法
2018/05/02 HTML / CSS
全球速卖通巴西站点:Aliexpress巴西
2016/08/24 全球购物
小学毕业寄语大全
2014/04/03 职场文书
博士毕业生自我鉴定范文
2014/04/13 职场文书
什么是求职信?求职信应包含哪些内容?
2019/08/14 职场文书
一劳永逸彻底解决pip install慢的办法
2021/05/24 Python
使用CSS实现音波加载效果
2023/05/07 HTML / CSS