pytorch下使用LSTM神经网络写诗实例


Posted in Python onJanuary 14, 2020

在pytorch下,以数万首唐诗为素材,训练双层LSTM神经网络,使其能够以唐诗的方式写诗。

代码结构分为四部分,分别为

1.model.py,定义了双层LSTM模型

2.data.py,定义了从网上得到的唐诗数据的处理方法

3.utlis.py 定义了损失可视化的函数

4.main.py定义了模型参数,以及训练、唐诗生成函数。

参考:电子工业出版社的《深度学习框架PyTorch:入门与实践》第九章

main代码及注释如下

import sys, os
import torch as t
from data import get_data
from model import PoetryModel
from torch import nn
from torch.autograd import Variable
from utils import Visualizer
import tqdm
from torchnet import meter
import ipdb
 
class Config(object):
	data_path = 'data/'
	pickle_path = 'tang.npz'
	author = None
	constrain = None
	category = 'poet.tang' #or poet.song
	lr = 1e-3
	weight_decay = 1e-4
	use_gpu = True
	epoch = 20
	batch_size = 128
	maxlen = 125
	plot_every = 20
	#use_env = True #是否使用visodm
	env = 'poety' 
	#visdom env
	max_gen_len = 200
	debug_file = '/tmp/debugp'
	model_path = None
	prefix_words = '细雨鱼儿出,微风燕子斜。' 
	#不是诗歌组成部分,是意境
	start_words = '闲云潭影日悠悠' 
	#诗歌开始
	acrostic = False 
	#是否藏头
	model_prefix = 'checkpoints/tang' 
	#模型保存路径
opt = Config()
 
def generate(model, start_words, ix2word, word2ix, prefix_words=None):
	'''
	给定几个词,根据这几个词接着生成一首完整的诗歌
	'''
	results = list(start_words)
	start_word_len = len(start_words)
	# 手动设置第一个词为<START>
	# 这个地方有问题,最后需要再看一下
	input = Variable(t.Tensor([word2ix['<START>']]).view(1,1).long())
	if opt.use_gpu:input=input.cuda()
	hidden = None
	
	if prefix_words:
		for word in prefix_words:
			output,hidden = model(input,hidden)
			# 下边这句话是为了把input变成1*1?
			input = Variable(input.data.new([word2ix[word]])).view(1,1)
	for i in range(opt.max_gen_len):
		output,hidden = model(input,hidden)
		
		if i<start_word_len:
			w = results[i]
			input = Variable(input.data.new([word2ix[w]])).view(1,1)
		else:
			top_index = output.data[0].topk(1)[1][0]
			w = ix2word[top_index]
			results.append(w)
			input = Variable(input.data.new([top_index])).view(1,1)
		if w=='<EOP>':
			del results[-1] #-1的意思是倒数第一个
			break
	return results
 
def gen_acrostic(model,start_words,ix2word,word2ix, prefix_words = None):
 '''
 生成藏头诗
 start_words : u'深度学习'
 生成:
 深木通中岳,青苔半日脂。
 度山分地险,逆浪到南巴。
 学道兵犹毒,当时燕不移。
 习根通古岸,开镜出清羸。
 '''
 results = []
 start_word_len = len(start_words)
 input = Variable(t.Tensor([word2ix['<START>']]).view(1,1).long())
 if opt.use_gpu:input=input.cuda()
 hidden = None
 
 index=0 # 用来指示已经生成了多少句藏头诗
 # 上一个词
 pre_word='<START>'
 
 if prefix_words:
  for word in prefix_words:
   output,hidden = model(input,hidden)
   input = Variable(input.data.new([word2ix[word]])).view(1,1)
 
 for i in range(opt.max_gen_len):
  output,hidden = model(input,hidden)
  top_index = output.data[0].topk(1)[1][0]
  w = ix2word[top_index]
 
  if (pre_word in {u'。',u'!','<START>'} ):
   # 如果遇到句号,藏头的词送进去生成
 
   if index==start_word_len:
    # 如果生成的诗歌已经包含全部藏头的词,则结束
    break
   else: 
    # 把藏头的词作为输入送入模型
    w = start_words[index]
    index+=1
    input = Variable(input.data.new([word2ix[w]])).view(1,1) 
  else:
   # 否则的话,把上一次预测是词作为下一个词输入
   input = Variable(input.data.new([word2ix[w]])).view(1,1)
  results.append(w)
  pre_word = w
 return results
 
def train(**kwargs):
	
	for k,v in kwargs.items():
		setattr(opt,k,v) #设置apt里属性的值
	vis = Visualizer(env=opt.env)
	
	#获取数据
	data, word2ix, ix2word = get_data(opt) #get_data是data.py里的函数
	data = t.from_numpy(data)
	#这个地方出错了,是大写的L
	dataloader = t.utils.data.DataLoader(data, 
					batch_size = opt.batch_size,
					shuffle = True,
					num_workers = 1) #在python里,这样写程序可以吗?
 #模型定义
	model = PoetryModel(len(word2ix), 128, 256)
	optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)
	criterion = nn.CrossEntropyLoss()
 
	if opt.model_path:
		model.load_state_dict(t.load(opt.model_path))
	if opt.use_gpu:
		model.cuda()
		criterion.cuda()
		
	#The tnt.AverageValueMeter measures and returns the average value 
	#and the standard deviation of any collection of numbers that are 
	#added to it. It is useful, for instance, to measure the average 
	#loss over a collection of examples.
 
 #The add() function expects as input a Lua number value, which 
 #is the value that needs to be added to the list of values to 
 #average. It also takes as input an optional parameter n that 
 #assigns a weight to value in the average, in order to facilitate 
 #computing weighted averages (default = 1).
 
 #The tnt.AverageValueMeter has no parameters to be set at initialization time. 
	loss_meter = meter.AverageValueMeter()
	
	for epoch in range(opt.epoch):
		loss_meter.reset()
		for ii,data_ in tqdm.tqdm(enumerate(dataloader)):
			#tqdm是python中的进度条
			#训练
			data_ = data_.long().transpose(1,0).contiguous()
			#上边一句话,把data_变成long类型,把1维和0维转置,把内存调成连续的
			if opt.use_gpu: data_ = data_.cuda()
			optimizer.zero_grad()
			input_, target = Variable(data_[:-1,:]), Variable(data_[1:,:])
			#上边一句,将输入的诗句错开一个字,形成训练和目标
			output,_ = model(input_)
			loss = criterion(output, target.view(-1))
			loss.backward()
			optimizer.step()
			
			loss_meter.add(loss.data[0]) #为什么是data[0]?
			
			#可视化用到的是utlis.py里的函数
			if (1+ii)%opt.plot_every ==0:
				
				if os.path.exists(opt.debug_file):
					ipdb.set_trace()
				vis.plot('loss',loss_meter.value()[0])
				
				# 下面是对目前模型情况的测试,诗歌原文
				poetrys = [[ix2word[_word] for _word in data_[:,_iii]] 
									for _iii in range(data_.size(1))][:16]
				#上面句子嵌套了两个循环,主要是将诗歌索引的前十六个字变成原文
				vis.text('</br>'.join([''.join(poetry) for poetry in 
				poetrys]),win = u'origin_poem')
				gen_poetries = []
				#分别以以下几个字作为诗歌的第一个字,生成8首诗
				for word in list(u'春江花月夜凉如水'):
					gen_poetry = ''.join(generate(model,word,ix2word,word2ix))
					gen_poetries.append(gen_poetry)
				vis.text('</br>'.join([''.join(poetry) for poetry in 
				gen_poetries]), win = u'gen_poem')
		t.save(model.state_dict(), '%s_%s.pth' %(opt.model_prefix,epoch))
 
def gen(**kwargs):
	'''
	提供命令行接口,用以生成相应的诗
	'''
	
	for k,v in kwargs.items():
		setattr(opt,k,v)
	data, word2ix, ix2word = get_data(opt)
	model = PoetryModel(len(word2ix), 128, 256)
	map_location = lambda s,l:s
	# 上边句子里的map_location是在load里用的,用以加载到指定的CPU或GPU,
	# 上边句子的意思是将模型加载到默认的GPU上
	state_dict = t.load(opt.model_path, map_location = map_location)
	model.load_state_dict(state_dict)
	
	if opt.use_gpu:
		model.cuda()
	if sys.version_info.major == 3:
		if opt.start_words.insprintable():
			start_words = opt.start_words
			prefix_words = opt.prefix_words if opt.prefix_words else None
		else:
			start_words = opt.start_words.encode('ascii',\
			'surrogateescape').decode('utf8')
			prefix_words = opt.prefix_words.encode('ascii',\
			'surrogateescape').decode('utf8') if opt.prefix_words else None
		start_words = start_words.replace(',',u',')\
											.replace('.',u'。')\
											.replace('?',u'?')
		gen_poetry = gen_acrostic if opt.acrostic else generate
		result = gen_poetry(model,start_words,ix2word,word2ix,prefix_words)
		print(''.join(result))
if __name__ == '__main__':
	import fire
	fire.Fire()

以上代码给我一些经验,

1. 了解python的编程方式,如空格、换行等;进一步了解python的各个基本模块;

2. 可能出的错误:函数名写错,大小写,变量名写错,括号不全。

3. 对cuda()的用法有了进一步认识;

4. 学会了调试程序(fire);

5. 学会了训练结果的可视化(visdom);

6. 进一步的了解了LSTM,对深度学习的架构、实现有了宏观把控。

这篇pytorch下使用LSTM神经网络写诗实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python备份文件的脚本
Aug 11 Python
python实现从字符串中找出字符1的位置以及个数的方法
Aug 25 Python
Python中的集合类型知识讲解
Aug 19 Python
Python时间模块datetime、time、calendar的使用方法
Jan 13 Python
使用Python的Twisted框架构建非阻塞下载程序的实例教程
May 25 Python
python使用str &amp; repr转换字符串
Oct 13 Python
python读取与写入csv格式文件的示例代码
Dec 16 Python
python用quad、dblquad实现一维二维积分的实例详解
Nov 20 Python
Python爬虫爬取电影票房数据及图表展示操作示例
Mar 27 Python
解决pycharm下pyuic工具使用的问题
Apr 08 Python
Python填充任意颜色,不同算法时间差异分析说明
May 16 Python
python代码实现将列表中重复元素之间的内容全部滤除
May 22 Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
python3.7通过thrift操作hbase的示例代码
Jan 14 #Python
解决pytorch DataLoader num_workers出现的问题
Jan 14 #Python
PyTorch实现ResNet50、ResNet101和ResNet152示例
Jan 14 #Python
python重要函数eval多种用法解析
Jan 14 #Python
You might like
PHP防盗链代码实例
2014/08/27 PHP
Zend Framework动作助手Json用法实例分析
2016/03/05 PHP
PHP清除缓存的几种方法总结
2017/09/12 PHP
php实现socket推送技术的示例
2017/12/20 PHP
PHP mkdir创建文件夹实现方法解析
2020/11/13 PHP
ie 调试javascript的工具
2009/04/29 Javascript
JqGrid web打印实现代码
2011/05/31 Javascript
JavaScript高级程序设计 学习笔记 js高级技巧
2011/09/20 Javascript
兼容ie、firefox的图片自动缩放的css跟js代码分享
2013/08/12 Javascript
javascript中对Attr(dom中属性)的操作示例讲解
2013/12/02 Javascript
javascript实现日期格式转换
2014/12/16 Javascript
jquery实现通用的内容渐显Tab选项卡效果
2015/09/07 Javascript
JavaScript编程学习技巧汇总
2016/02/21 Javascript
实例剖析AngularJS框架中数据的双向绑定运用
2016/03/04 Javascript
js的form表单提交url传参数(包含+等特殊字符)的两种解决方法
2016/05/25 Javascript
webpack入门+react环境配置
2017/02/08 Javascript
javascript 使用正则test( )第一次是 true,第二次是false
2017/02/22 Javascript
基于vue v-for 循环复选框-默认勾选第一个的实现方法
2018/03/03 Javascript
微信小程序提交form操作示例
2018/12/30 Javascript
基于vue.js实现购物车
2020/01/15 Javascript
[01:00:10]完美世界DOTA2联赛PWL S2 FTD vs Inki 第二场 11.21
2020/11/24 DOTA
python机器学习实战之树回归详解
2017/12/20 Python
Python实现读取SQLServer数据并插入到MongoDB数据库的方法示例
2018/06/09 Python
详解Python对JSON中的特殊类型进行Encoder
2019/07/15 Python
Django 查询数据库并返回页面的例子
2019/08/12 Python
详解Flask前后端分离项目案例
2020/07/24 Python
美国网上眼镜供应商:LEOTONY(眼镜、RX太阳镜和太阳镜)
2017/10/31 全球购物
90后毕业生的求职信范文
2013/09/21 职场文书
暑期研修感言
2014/02/17 职场文书
听课评语大全
2014/04/30 职场文书
班组长安全工作职责
2014/07/15 职场文书
小学教师学习党的群众路线教育实践活动心得体会
2014/10/31 职场文书
人事聘任通知
2015/04/21 职场文书
面试复试通知单
2015/04/24 职场文书
css display table 自适应高度、宽度问题的解决
2021/05/07 HTML / CSS
Redis缓存-序列化对象存储乱码问题的解决
2021/06/21 Redis