pytorch下使用LSTM神经网络写诗实例


Posted in Python onJanuary 14, 2020

在pytorch下,以数万首唐诗为素材,训练双层LSTM神经网络,使其能够以唐诗的方式写诗。

代码结构分为四部分,分别为

1.model.py,定义了双层LSTM模型

2.data.py,定义了从网上得到的唐诗数据的处理方法

3.utlis.py 定义了损失可视化的函数

4.main.py定义了模型参数,以及训练、唐诗生成函数。

参考:电子工业出版社的《深度学习框架PyTorch:入门与实践》第九章

main代码及注释如下

import sys, os
import torch as t
from data import get_data
from model import PoetryModel
from torch import nn
from torch.autograd import Variable
from utils import Visualizer
import tqdm
from torchnet import meter
import ipdb
 
class Config(object):
	data_path = 'data/'
	pickle_path = 'tang.npz'
	author = None
	constrain = None
	category = 'poet.tang' #or poet.song
	lr = 1e-3
	weight_decay = 1e-4
	use_gpu = True
	epoch = 20
	batch_size = 128
	maxlen = 125
	plot_every = 20
	#use_env = True #是否使用visodm
	env = 'poety' 
	#visdom env
	max_gen_len = 200
	debug_file = '/tmp/debugp'
	model_path = None
	prefix_words = '细雨鱼儿出,微风燕子斜。' 
	#不是诗歌组成部分,是意境
	start_words = '闲云潭影日悠悠' 
	#诗歌开始
	acrostic = False 
	#是否藏头
	model_prefix = 'checkpoints/tang' 
	#模型保存路径
opt = Config()
 
def generate(model, start_words, ix2word, word2ix, prefix_words=None):
	'''
	给定几个词,根据这几个词接着生成一首完整的诗歌
	'''
	results = list(start_words)
	start_word_len = len(start_words)
	# 手动设置第一个词为<START>
	# 这个地方有问题,最后需要再看一下
	input = Variable(t.Tensor([word2ix['<START>']]).view(1,1).long())
	if opt.use_gpu:input=input.cuda()
	hidden = None
	
	if prefix_words:
		for word in prefix_words:
			output,hidden = model(input,hidden)
			# 下边这句话是为了把input变成1*1?
			input = Variable(input.data.new([word2ix[word]])).view(1,1)
	for i in range(opt.max_gen_len):
		output,hidden = model(input,hidden)
		
		if i<start_word_len:
			w = results[i]
			input = Variable(input.data.new([word2ix[w]])).view(1,1)
		else:
			top_index = output.data[0].topk(1)[1][0]
			w = ix2word[top_index]
			results.append(w)
			input = Variable(input.data.new([top_index])).view(1,1)
		if w=='<EOP>':
			del results[-1] #-1的意思是倒数第一个
			break
	return results
 
def gen_acrostic(model,start_words,ix2word,word2ix, prefix_words = None):
 '''
 生成藏头诗
 start_words : u'深度学习'
 生成:
 深木通中岳,青苔半日脂。
 度山分地险,逆浪到南巴。
 学道兵犹毒,当时燕不移。
 习根通古岸,开镜出清羸。
 '''
 results = []
 start_word_len = len(start_words)
 input = Variable(t.Tensor([word2ix['<START>']]).view(1,1).long())
 if opt.use_gpu:input=input.cuda()
 hidden = None
 
 index=0 # 用来指示已经生成了多少句藏头诗
 # 上一个词
 pre_word='<START>'
 
 if prefix_words:
  for word in prefix_words:
   output,hidden = model(input,hidden)
   input = Variable(input.data.new([word2ix[word]])).view(1,1)
 
 for i in range(opt.max_gen_len):
  output,hidden = model(input,hidden)
  top_index = output.data[0].topk(1)[1][0]
  w = ix2word[top_index]
 
  if (pre_word in {u'。',u'!','<START>'} ):
   # 如果遇到句号,藏头的词送进去生成
 
   if index==start_word_len:
    # 如果生成的诗歌已经包含全部藏头的词,则结束
    break
   else: 
    # 把藏头的词作为输入送入模型
    w = start_words[index]
    index+=1
    input = Variable(input.data.new([word2ix[w]])).view(1,1) 
  else:
   # 否则的话,把上一次预测是词作为下一个词输入
   input = Variable(input.data.new([word2ix[w]])).view(1,1)
  results.append(w)
  pre_word = w
 return results
 
def train(**kwargs):
	
	for k,v in kwargs.items():
		setattr(opt,k,v) #设置apt里属性的值
	vis = Visualizer(env=opt.env)
	
	#获取数据
	data, word2ix, ix2word = get_data(opt) #get_data是data.py里的函数
	data = t.from_numpy(data)
	#这个地方出错了,是大写的L
	dataloader = t.utils.data.DataLoader(data, 
					batch_size = opt.batch_size,
					shuffle = True,
					num_workers = 1) #在python里,这样写程序可以吗?
 #模型定义
	model = PoetryModel(len(word2ix), 128, 256)
	optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)
	criterion = nn.CrossEntropyLoss()
 
	if opt.model_path:
		model.load_state_dict(t.load(opt.model_path))
	if opt.use_gpu:
		model.cuda()
		criterion.cuda()
		
	#The tnt.AverageValueMeter measures and returns the average value 
	#and the standard deviation of any collection of numbers that are 
	#added to it. It is useful, for instance, to measure the average 
	#loss over a collection of examples.
 
 #The add() function expects as input a Lua number value, which 
 #is the value that needs to be added to the list of values to 
 #average. It also takes as input an optional parameter n that 
 #assigns a weight to value in the average, in order to facilitate 
 #computing weighted averages (default = 1).
 
 #The tnt.AverageValueMeter has no parameters to be set at initialization time. 
	loss_meter = meter.AverageValueMeter()
	
	for epoch in range(opt.epoch):
		loss_meter.reset()
		for ii,data_ in tqdm.tqdm(enumerate(dataloader)):
			#tqdm是python中的进度条
			#训练
			data_ = data_.long().transpose(1,0).contiguous()
			#上边一句话,把data_变成long类型,把1维和0维转置,把内存调成连续的
			if opt.use_gpu: data_ = data_.cuda()
			optimizer.zero_grad()
			input_, target = Variable(data_[:-1,:]), Variable(data_[1:,:])
			#上边一句,将输入的诗句错开一个字,形成训练和目标
			output,_ = model(input_)
			loss = criterion(output, target.view(-1))
			loss.backward()
			optimizer.step()
			
			loss_meter.add(loss.data[0]) #为什么是data[0]?
			
			#可视化用到的是utlis.py里的函数
			if (1+ii)%opt.plot_every ==0:
				
				if os.path.exists(opt.debug_file):
					ipdb.set_trace()
				vis.plot('loss',loss_meter.value()[0])
				
				# 下面是对目前模型情况的测试,诗歌原文
				poetrys = [[ix2word[_word] for _word in data_[:,_iii]] 
									for _iii in range(data_.size(1))][:16]
				#上面句子嵌套了两个循环,主要是将诗歌索引的前十六个字变成原文
				vis.text('</br>'.join([''.join(poetry) for poetry in 
				poetrys]),win = u'origin_poem')
				gen_poetries = []
				#分别以以下几个字作为诗歌的第一个字,生成8首诗
				for word in list(u'春江花月夜凉如水'):
					gen_poetry = ''.join(generate(model,word,ix2word,word2ix))
					gen_poetries.append(gen_poetry)
				vis.text('</br>'.join([''.join(poetry) for poetry in 
				gen_poetries]), win = u'gen_poem')
		t.save(model.state_dict(), '%s_%s.pth' %(opt.model_prefix,epoch))
 
def gen(**kwargs):
	'''
	提供命令行接口,用以生成相应的诗
	'''
	
	for k,v in kwargs.items():
		setattr(opt,k,v)
	data, word2ix, ix2word = get_data(opt)
	model = PoetryModel(len(word2ix), 128, 256)
	map_location = lambda s,l:s
	# 上边句子里的map_location是在load里用的,用以加载到指定的CPU或GPU,
	# 上边句子的意思是将模型加载到默认的GPU上
	state_dict = t.load(opt.model_path, map_location = map_location)
	model.load_state_dict(state_dict)
	
	if opt.use_gpu:
		model.cuda()
	if sys.version_info.major == 3:
		if opt.start_words.insprintable():
			start_words = opt.start_words
			prefix_words = opt.prefix_words if opt.prefix_words else None
		else:
			start_words = opt.start_words.encode('ascii',\
			'surrogateescape').decode('utf8')
			prefix_words = opt.prefix_words.encode('ascii',\
			'surrogateescape').decode('utf8') if opt.prefix_words else None
		start_words = start_words.replace(',',u',')\
											.replace('.',u'。')\
											.replace('?',u'?')
		gen_poetry = gen_acrostic if opt.acrostic else generate
		result = gen_poetry(model,start_words,ix2word,word2ix,prefix_words)
		print(''.join(result))
if __name__ == '__main__':
	import fire
	fire.Fire()

以上代码给我一些经验,

1. 了解python的编程方式,如空格、换行等;进一步了解python的各个基本模块;

2. 可能出的错误:函数名写错,大小写,变量名写错,括号不全。

3. 对cuda()的用法有了进一步认识;

4. 学会了调试程序(fire);

5. 学会了训练结果的可视化(visdom);

6. 进一步的了解了LSTM,对深度学习的架构、实现有了宏观把控。

这篇pytorch下使用LSTM神经网络写诗实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 字符串大小写转换的简单实例
Jan 21 Python
Python实现数据可视化看如何监控你的爬虫状态【推荐】
Aug 10 Python
Python小工具之消耗系统指定大小内存的方法
Dec 03 Python
PyQt5创建一个新窗口的实例
Jun 20 Python
numpy中的meshgrid函数的使用
Jul 31 Python
pytorch numpy list类型之间的相互转换实例
Aug 18 Python
Python 获取指定文件夹下的目录和文件的实现
Aug 30 Python
python3实现用turtle模块画一棵随机樱花树
Nov 21 Python
python3读取csv文件任意行列代码实例
Jan 13 Python
使用matplotlib的pyplot模块绘图的实现示例
Jul 12 Python
python写文件时覆盖原来的实例方法
Jul 22 Python
python数据库批量插入数据的实现(executemany的使用)
Apr 30 Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
python3.7通过thrift操作hbase的示例代码
Jan 14 #Python
解决pytorch DataLoader num_workers出现的问题
Jan 14 #Python
PyTorch实现ResNet50、ResNet101和ResNet152示例
Jan 14 #Python
python重要函数eval多种用法解析
Jan 14 #Python
You might like
PHP新手上路(十)
2006/10/09 PHP
深入Memcache的Session数据的多服务器共享详解
2013/06/13 PHP
thinkphp使用literal防止模板标签被解析的方法
2014/11/22 PHP
layui框架实现文件上传及TP3.2.3(thinkPHP)对上传文件进行后台处理操作示例
2018/05/12 PHP
PHP5.6.8连接SQL Server 2008 R2数据库常用技巧分析总结
2019/05/06 PHP
javascript的offset、client、scroll使用方法详解
2012/12/25 Javascript
JQuery slideshow的一个小问题(如何发现及解决过程)
2013/02/06 Javascript
完美兼容各大浏览器获取HTTP_REFERER方法总结
2014/06/24 Javascript
JavaScript中的ArrayBuffer详细介绍
2014/12/08 Javascript
最简单的JavaScript验证整数、小数、实数、有效位小数正则表达式
2015/04/17 Javascript
jQuery实现无限往下滚动效果代码
2016/04/16 Javascript
JavaScript操作表单实例讲解(上)
2016/06/20 Javascript
详解Bootstrap的iCheck插件checkbox和radio
2016/08/24 Javascript
基于JavaScript实现自定义滚动条
2017/01/25 Javascript
JS实现的找零张数最小问题示例
2017/11/28 Javascript
electron中使用bootstrap的示例代码
2018/11/06 Javascript
vue中beforeRouteLeave实现页面回退不刷新的示例代码
2019/11/01 Javascript
如何利用node.js开发一个生成逐帧动画的小工具
2019/12/01 Javascript
python切换hosts文件代码示例
2013/12/31 Python
Python实现检测服务器是否可以ping通的2种方法
2015/01/01 Python
简化Python的Django框架代码的一些示例
2015/04/20 Python
Python操作SQLite数据库的方法详解
2017/06/16 Python
python版学生管理系统
2018/01/10 Python
用Python下载一个网页保存为本地的HTML文件实例
2018/05/21 Python
python读取TXT每行,并存到LIST中的方法
2018/10/26 Python
python模拟登陆,用session维持回话的实例
2018/12/27 Python
python爬虫快速响应服务器的做法
2020/11/24 Python
Python爬虫进阶之爬取某视频并下载的实现
2020/12/08 Python
英国领先的男士服装和时尚零售商:Burton
2017/01/09 全球购物
Ticketmaster意大利:音乐会、节日、艺术和剧院的官方门票
2019/12/23 全球购物
应届生自荐书
2014/06/23 职场文书
班主任工作总结范文
2015/08/13 职场文书
有关花店创业的计划书模板
2019/08/27 职场文书
html5表单的required属性使用
2021/07/07 HTML / CSS
Python 阶乘详解
2021/10/05 Python
python多线程方法详解
2022/01/18 Python