PyTorch预训练Bert模型的示例


Posted in Python onNovember 17, 2020

本文介绍以下内容:
1. 使用transformers框架做预训练的bert-base模型;
2. 开发平台使用Google的Colab平台,白嫖GPU加速;
3. 使用datasets模块下载IMDB影评数据作为训练数据。

transformers模块简介

transformers框架为Huggingface开源的深度学习框架,支持几乎所有的Transformer架构的预训练模型。使用非常的方便,本文基于此框架,尝试一下预训练模型的使用,简单易用。

本来打算预训练bert-large模型,发现colab上GPU显存不够用,只能使用base版本了。打开colab,并且设置好GPU加速,接下来开始介绍代码。

代码实现

首先安装数据下载模块和transformers包。

pip install datasets
pip install transformers

使用datasets下载IMDB数据,返回DatasetDict类型的数据.返回的数据是文本类型,需要进行编码。下面会使用tokenizer进行编码。

from datasets import load_dataset

imdb = load_dataset('imdb')
print(imdb['train'][:3]) # 打印前3条训练数据

接下来加载tokenizer和模型.从transformers导入AutoModelForSequenceClassification, AutoTokenizer,创建模型和tokenizer。

from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_checkpoint = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)

对原始数据进行编码,并且分批次(batch)

def preprocessing_func(examples):
  return tokenizer(examples['text'], 
           padding=True,
           truncation=True, max_length=300)

batch_size = 16

encoded_data = imdb.map(preprocessing_func, batched=True, batch_size=batch_size)

上面得到编码数据,每个批次设置为16.接下来需要指定训练的参数,训练参数的指定使用transformers给出的接口类TrainingArguments,模型的训练可以使用Trainer。

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
  'out',
  per_device_train_batch_size=batch_size,
  per_device_eval_batch_size=batch_size,
  learning_rate=5e-5,
  evaluation_strategy='epoch',
  num_train_epochs=10,
  load_best_model_at_end=True,
)

trainer = Trainer(
  model,
  args=args,
  train_dataset=encoded_data['train'],
  eval_dataset=encoded_data['test'],
  tokenizer=tokenizer
)

训练模型使用trainer对象的train方法

trainer.train()

PyTorch预训练Bert模型的示例

评估模型使用trainer对象的evaluate方法

trainer.evaluate()

总结

本文介绍了基于transformers框架实现的bert预训练模型,此框架提供了非常友好的接口,可以方便读者尝试各种预训练模型。同时datasets也提供了很多数据集,便于学习NLP的各种问题。加上Google提供的colab环境,数据下载和预训练模型下载都非常快,建议读者自行去炼丹。本文完整的案例下载

以上就是PyTorch预训练Bert模型的示例的详细内容,更多关于PyTorch预训练Bert模型的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
从零学Python之入门(五)缩进和选择
May 27 Python
Python简明入门教程
Aug 04 Python
Python的Tornado框架的异步任务与AsyncHTTPClient
Jun 27 Python
python 异或加密字符串的实例
Oct 14 Python
Python装饰器用法实例分析
Jan 14 Python
Python画图实现同一结点多个柱状图的示例
Jul 07 Python
Python项目 基于Scapy实现SYN泛洪攻击的方法
Jul 23 Python
Python字符串、列表、元组、字典、集合的补充实例详解
Dec 20 Python
Python实现AI换脸功能
Apr 10 Python
Python使用文件操作实现一个XX信息管理系统的示例
Jul 02 Python
Python使用xpath实现图片爬取
Sep 16 Python
python实现简单贪吃蛇游戏
Sep 29 Python
python 下载文件的多种方法汇总
Nov 17 #Python
python跨文件使用全局变量的实现
Nov 17 #Python
Python中logging日志的四个等级和使用
Nov 17 #Python
Python爬虫破解登陆哔哩哔哩的方法
Nov 17 #Python
appium+python自动化配置(adk、jdk、node.js)
Nov 17 #Python
python调用百度API实现人脸识别
Nov 17 #Python
详解利用python识别图片中的条码(pyzbar)及条码图片矫正和增强
Nov 17 #Python
You might like
php数组函数序列之array_combine() - 数组合并函数使用说明
2011/10/29 PHP
phalcon model在插入或更新时会自动验证非空字段的解决办法
2016/12/29 PHP
jquery 模式对话框终极版实现代码
2009/09/28 Javascript
jquery 常用操作整理 基础入门篇
2009/10/14 Javascript
javascript KeyDown、KeyPress和KeyUp事件的区别与联系
2009/12/03 Javascript
js 单击式的下拉菜单效果实例
2013/08/13 Javascript
JavaScript数组函数unshift、shift、pop、push使用实例
2014/08/27 Javascript
DOM基础教程之使用DOM
2015/01/19 Javascript
Javascript变量的作用域和作用域链详解
2015/04/02 Javascript
javascript随机抽取0-100之间不重复的10个数
2016/02/25 Javascript
javascript中eval解析JSON字符串
2016/02/27 Javascript
JS随机洗牌算法之数组随机排序
2016/03/23 Javascript
如何用JavaScript实现动态修改CSS样式表
2016/05/20 Javascript
Node.js的Koa框架上手及MySQL操作指南
2016/06/13 Javascript
JavaScript仿flash遮罩动画效果
2016/06/15 Javascript
ionic在开发ios系统微信时键盘挡住输入框的解决方法(键盘弹出问题)
2016/09/06 Javascript
Ajax跨域实现代码(后台jsp)
2017/01/21 Javascript
nodejs实现OAuth2.0授权服务认证
2017/12/27 NodeJs
详解vue2.0 不同屏幕适配及px与rem转换问题
2018/02/23 Javascript
在vue中多次调用同一个定义全局变量的实例
2018/09/25 Javascript
javascript中join方法实例讲解
2019/02/21 Javascript
微信小程序自定义组件实现环形进度条
2020/11/17 Javascript
Vuex的各个模块封装的实现
2020/06/05 Javascript
Linux下将Python的Django项目部署到Apache服务器
2015/12/24 Python
详解python的ORM中Pony用法
2018/02/09 Python
pytorch 转换矩阵的维数位置方法
2018/12/08 Python
python实现Excel文件转换为TXT文件
2019/04/28 Python
Python urlopen()参数代码示例解析
2020/12/10 Python
雅高酒店中国:Accorhotels.com China
2018/03/26 全球购物
英国在线自行车店:Merlin Cycles
2018/08/20 全球购物
大学生自荐信
2013/12/11 职场文书
简单通用的简历自我评价
2014/09/21 职场文书
鸦片战争观后感
2015/06/09 职场文书
Python实现机器学习算法的分类
2021/06/03 Python
PostgreSQL怎么创建分区表详解
2022/06/25 PostgreSQL
MySQL一劳永逸永久支持输入中文的方法实例
2022/08/05 MySQL