PyTorch预训练Bert模型的示例


Posted in Python onNovember 17, 2020

本文介绍以下内容:
1. 使用transformers框架做预训练的bert-base模型;
2. 开发平台使用Google的Colab平台,白嫖GPU加速;
3. 使用datasets模块下载IMDB影评数据作为训练数据。

transformers模块简介

transformers框架为Huggingface开源的深度学习框架,支持几乎所有的Transformer架构的预训练模型。使用非常的方便,本文基于此框架,尝试一下预训练模型的使用,简单易用。

本来打算预训练bert-large模型,发现colab上GPU显存不够用,只能使用base版本了。打开colab,并且设置好GPU加速,接下来开始介绍代码。

代码实现

首先安装数据下载模块和transformers包。

pip install datasets
pip install transformers

使用datasets下载IMDB数据,返回DatasetDict类型的数据.返回的数据是文本类型,需要进行编码。下面会使用tokenizer进行编码。

from datasets import load_dataset

imdb = load_dataset('imdb')
print(imdb['train'][:3]) # 打印前3条训练数据

接下来加载tokenizer和模型.从transformers导入AutoModelForSequenceClassification, AutoTokenizer,创建模型和tokenizer。

from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_checkpoint = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)

对原始数据进行编码,并且分批次(batch)

def preprocessing_func(examples):
  return tokenizer(examples['text'], 
           padding=True,
           truncation=True, max_length=300)

batch_size = 16

encoded_data = imdb.map(preprocessing_func, batched=True, batch_size=batch_size)

上面得到编码数据,每个批次设置为16.接下来需要指定训练的参数,训练参数的指定使用transformers给出的接口类TrainingArguments,模型的训练可以使用Trainer。

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
  'out',
  per_device_train_batch_size=batch_size,
  per_device_eval_batch_size=batch_size,
  learning_rate=5e-5,
  evaluation_strategy='epoch',
  num_train_epochs=10,
  load_best_model_at_end=True,
)

trainer = Trainer(
  model,
  args=args,
  train_dataset=encoded_data['train'],
  eval_dataset=encoded_data['test'],
  tokenizer=tokenizer
)

训练模型使用trainer对象的train方法

trainer.train()

PyTorch预训练Bert模型的示例

评估模型使用trainer对象的evaluate方法

trainer.evaluate()

总结

本文介绍了基于transformers框架实现的bert预训练模型,此框架提供了非常友好的接口,可以方便读者尝试各种预训练模型。同时datasets也提供了很多数据集,便于学习NLP的各种问题。加上Google提供的colab环境,数据下载和预训练模型下载都非常快,建议读者自行去炼丹。本文完整的案例下载

以上就是PyTorch预训练Bert模型的示例的详细内容,更多关于PyTorch预训练Bert模型的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python3 能振兴 Python的原因分析
Nov 28 Python
python select.select模块通信全过程解析
Sep 20 Python
Python设计模式之中介模式简单示例
Jan 09 Python
django静态文件加载的方法
May 20 Python
如何在python字符串中输入纯粹的{}
Aug 22 Python
对python使用telnet实现弱密码登录的方法详解
Jan 26 Python
python写日志文件操作类与应用示例
Jul 01 Python
python实现的爬取电影下载链接功能示例
Aug 26 Python
django连接oracle时setting 配置方法
Aug 29 Python
Python如何将图像音视频等资源文件隐藏在代码中(小技巧)
Feb 16 Python
windows上彻底删除jupyter notebook的实现
Apr 13 Python
Python填充任意颜色,不同算法时间差异分析说明
May 16 Python
python 下载文件的多种方法汇总
Nov 17 #Python
python跨文件使用全局变量的实现
Nov 17 #Python
Python中logging日志的四个等级和使用
Nov 17 #Python
Python爬虫破解登陆哔哩哔哩的方法
Nov 17 #Python
appium+python自动化配置(adk、jdk、node.js)
Nov 17 #Python
python调用百度API实现人脸识别
Nov 17 #Python
详解利用python识别图片中的条码(pyzbar)及条码图片矫正和增强
Nov 17 #Python
You might like
Session保存到数据库的php类分享
2011/10/24 PHP
php 中文字符串首字母的获取函数分享
2013/11/04 PHP
windows中为php安装mongodb与memcache
2015/01/06 PHP
php数组随机排序实现方法
2015/06/13 PHP
通过PHP自带的服务器来查看正则匹配结果的方法
2015/12/24 PHP
Joomla语言翻译类Jtext用法分析
2016/05/05 PHP
thinkPHP框架实现的简单计算器示例
2018/12/07 PHP
微信公众平台开发教程④ ThinkPHP框架下微信支付功能图文详解
2019/04/10 PHP
学习js所必须要知道的一些
2007/03/07 Javascript
JS返回上一页实例代码通过图片和按钮分别实现
2013/08/16 Javascript
js超时调用setTimeout和间歇调用setInterval实例分析
2015/01/28 Javascript
jQuery.position()方法获取不到值的安全替换方法
2015/03/13 Javascript
javascript仿百度输入框提示自动下拉补全
2016/01/07 Javascript
js html5 css俄罗斯方块游戏再现
2016/10/17 Javascript
AngularJS监听ng-repeat渲染完成的方法
2018/03/20 Javascript
vue 实现在函数中触发路由跳转的示例
2018/09/01 Javascript
JS获取并处理php数组的方法实例分析
2018/09/04 Javascript
React降级配置及Ant Design配置详解
2018/12/27 Javascript
vue.js使用v-model实现表单元素(input) 双向数据绑定功能示例
2019/03/08 Javascript
[00:49]完美世界DOTA2联赛10月28日开团时刻:随便打
2020/10/29 DOTA
Python安装官方whl包和tar.gz包的方法(推荐)
2017/06/04 Python
pygame游戏之旅 创建游戏窗口界面
2018/11/20 Python
python tkinter库实现气泡屏保和锁屏
2019/07/29 Python
python实现大战外星人小游戏实例代码
2019/12/26 Python
python 如何把docker-compose.yaml导入到数据库相关条目里
2021/01/15 Python
SmartBuyGlasses台湾:名牌眼镜,名牌太阳眼镜及隐形眼镜
2017/01/04 全球购物
美国奢侈品在线团购网站:Gilt City
2017/11/16 全球购物
以下为Windows NT 下的32 位C++程序,请计算sizeof 的值
2016/12/07 面试题
一套比较完整的软件测试人员面试题
2012/05/13 面试题
大学生社会实践方案
2014/05/11 职场文书
党员个人对照检查材料范文
2014/09/24 职场文书
党的群众路线教育实践活动教师自我剖析材料
2014/10/09 职场文书
导游欢送词
2015/01/31 职场文书
房屋产权证明书
2015/06/19 职场文书
学习计划是什么
2019/04/30 职场文书
优秀范文:读《红岩》有感3篇
2019/10/14 职场文书