pytorch下使用LSTM神经网络写诗实例


Posted in Python onJanuary 14, 2020

在pytorch下,以数万首唐诗为素材,训练双层LSTM神经网络,使其能够以唐诗的方式写诗。

代码结构分为四部分,分别为

1.model.py,定义了双层LSTM模型

2.data.py,定义了从网上得到的唐诗数据的处理方法

3.utlis.py 定义了损失可视化的函数

4.main.py定义了模型参数,以及训练、唐诗生成函数。

参考:电子工业出版社的《深度学习框架PyTorch:入门与实践》第九章

main代码及注释如下

import sys, os
import torch as t
from data import get_data
from model import PoetryModel
from torch import nn
from torch.autograd import Variable
from utils import Visualizer
import tqdm
from torchnet import meter
import ipdb
 
class Config(object):
	data_path = 'data/'
	pickle_path = 'tang.npz'
	author = None
	constrain = None
	category = 'poet.tang' #or poet.song
	lr = 1e-3
	weight_decay = 1e-4
	use_gpu = True
	epoch = 20
	batch_size = 128
	maxlen = 125
	plot_every = 20
	#use_env = True #是否使用visodm
	env = 'poety' 
	#visdom env
	max_gen_len = 200
	debug_file = '/tmp/debugp'
	model_path = None
	prefix_words = '细雨鱼儿出,微风燕子斜。' 
	#不是诗歌组成部分,是意境
	start_words = '闲云潭影日悠悠' 
	#诗歌开始
	acrostic = False 
	#是否藏头
	model_prefix = 'checkpoints/tang' 
	#模型保存路径
opt = Config()
 
def generate(model, start_words, ix2word, word2ix, prefix_words=None):
	'''
	给定几个词,根据这几个词接着生成一首完整的诗歌
	'''
	results = list(start_words)
	start_word_len = len(start_words)
	# 手动设置第一个词为<START>
	# 这个地方有问题,最后需要再看一下
	input = Variable(t.Tensor([word2ix['<START>']]).view(1,1).long())
	if opt.use_gpu:input=input.cuda()
	hidden = None
	
	if prefix_words:
		for word in prefix_words:
			output,hidden = model(input,hidden)
			# 下边这句话是为了把input变成1*1?
			input = Variable(input.data.new([word2ix[word]])).view(1,1)
	for i in range(opt.max_gen_len):
		output,hidden = model(input,hidden)
		
		if i<start_word_len:
			w = results[i]
			input = Variable(input.data.new([word2ix[w]])).view(1,1)
		else:
			top_index = output.data[0].topk(1)[1][0]
			w = ix2word[top_index]
			results.append(w)
			input = Variable(input.data.new([top_index])).view(1,1)
		if w=='<EOP>':
			del results[-1] #-1的意思是倒数第一个
			break
	return results
 
def gen_acrostic(model,start_words,ix2word,word2ix, prefix_words = None):
 '''
 生成藏头诗
 start_words : u'深度学习'
 生成:
 深木通中岳,青苔半日脂。
 度山分地险,逆浪到南巴。
 学道兵犹毒,当时燕不移。
 习根通古岸,开镜出清羸。
 '''
 results = []
 start_word_len = len(start_words)
 input = Variable(t.Tensor([word2ix['<START>']]).view(1,1).long())
 if opt.use_gpu:input=input.cuda()
 hidden = None
 
 index=0 # 用来指示已经生成了多少句藏头诗
 # 上一个词
 pre_word='<START>'
 
 if prefix_words:
  for word in prefix_words:
   output,hidden = model(input,hidden)
   input = Variable(input.data.new([word2ix[word]])).view(1,1)
 
 for i in range(opt.max_gen_len):
  output,hidden = model(input,hidden)
  top_index = output.data[0].topk(1)[1][0]
  w = ix2word[top_index]
 
  if (pre_word in {u'。',u'!','<START>'} ):
   # 如果遇到句号,藏头的词送进去生成
 
   if index==start_word_len:
    # 如果生成的诗歌已经包含全部藏头的词,则结束
    break
   else: 
    # 把藏头的词作为输入送入模型
    w = start_words[index]
    index+=1
    input = Variable(input.data.new([word2ix[w]])).view(1,1) 
  else:
   # 否则的话,把上一次预测是词作为下一个词输入
   input = Variable(input.data.new([word2ix[w]])).view(1,1)
  results.append(w)
  pre_word = w
 return results
 
def train(**kwargs):
	
	for k,v in kwargs.items():
		setattr(opt,k,v) #设置apt里属性的值
	vis = Visualizer(env=opt.env)
	
	#获取数据
	data, word2ix, ix2word = get_data(opt) #get_data是data.py里的函数
	data = t.from_numpy(data)
	#这个地方出错了,是大写的L
	dataloader = t.utils.data.DataLoader(data, 
					batch_size = opt.batch_size,
					shuffle = True,
					num_workers = 1) #在python里,这样写程序可以吗?
 #模型定义
	model = PoetryModel(len(word2ix), 128, 256)
	optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)
	criterion = nn.CrossEntropyLoss()
 
	if opt.model_path:
		model.load_state_dict(t.load(opt.model_path))
	if opt.use_gpu:
		model.cuda()
		criterion.cuda()
		
	#The tnt.AverageValueMeter measures and returns the average value 
	#and the standard deviation of any collection of numbers that are 
	#added to it. It is useful, for instance, to measure the average 
	#loss over a collection of examples.
 
 #The add() function expects as input a Lua number value, which 
 #is the value that needs to be added to the list of values to 
 #average. It also takes as input an optional parameter n that 
 #assigns a weight to value in the average, in order to facilitate 
 #computing weighted averages (default = 1).
 
 #The tnt.AverageValueMeter has no parameters to be set at initialization time. 
	loss_meter = meter.AverageValueMeter()
	
	for epoch in range(opt.epoch):
		loss_meter.reset()
		for ii,data_ in tqdm.tqdm(enumerate(dataloader)):
			#tqdm是python中的进度条
			#训练
			data_ = data_.long().transpose(1,0).contiguous()
			#上边一句话,把data_变成long类型,把1维和0维转置,把内存调成连续的
			if opt.use_gpu: data_ = data_.cuda()
			optimizer.zero_grad()
			input_, target = Variable(data_[:-1,:]), Variable(data_[1:,:])
			#上边一句,将输入的诗句错开一个字,形成训练和目标
			output,_ = model(input_)
			loss = criterion(output, target.view(-1))
			loss.backward()
			optimizer.step()
			
			loss_meter.add(loss.data[0]) #为什么是data[0]?
			
			#可视化用到的是utlis.py里的函数
			if (1+ii)%opt.plot_every ==0:
				
				if os.path.exists(opt.debug_file):
					ipdb.set_trace()
				vis.plot('loss',loss_meter.value()[0])
				
				# 下面是对目前模型情况的测试,诗歌原文
				poetrys = [[ix2word[_word] for _word in data_[:,_iii]] 
									for _iii in range(data_.size(1))][:16]
				#上面句子嵌套了两个循环,主要是将诗歌索引的前十六个字变成原文
				vis.text('</br>'.join([''.join(poetry) for poetry in 
				poetrys]),win = u'origin_poem')
				gen_poetries = []
				#分别以以下几个字作为诗歌的第一个字,生成8首诗
				for word in list(u'春江花月夜凉如水'):
					gen_poetry = ''.join(generate(model,word,ix2word,word2ix))
					gen_poetries.append(gen_poetry)
				vis.text('</br>'.join([''.join(poetry) for poetry in 
				gen_poetries]), win = u'gen_poem')
		t.save(model.state_dict(), '%s_%s.pth' %(opt.model_prefix,epoch))
 
def gen(**kwargs):
	'''
	提供命令行接口,用以生成相应的诗
	'''
	
	for k,v in kwargs.items():
		setattr(opt,k,v)
	data, word2ix, ix2word = get_data(opt)
	model = PoetryModel(len(word2ix), 128, 256)
	map_location = lambda s,l:s
	# 上边句子里的map_location是在load里用的,用以加载到指定的CPU或GPU,
	# 上边句子的意思是将模型加载到默认的GPU上
	state_dict = t.load(opt.model_path, map_location = map_location)
	model.load_state_dict(state_dict)
	
	if opt.use_gpu:
		model.cuda()
	if sys.version_info.major == 3:
		if opt.start_words.insprintable():
			start_words = opt.start_words
			prefix_words = opt.prefix_words if opt.prefix_words else None
		else:
			start_words = opt.start_words.encode('ascii',\
			'surrogateescape').decode('utf8')
			prefix_words = opt.prefix_words.encode('ascii',\
			'surrogateescape').decode('utf8') if opt.prefix_words else None
		start_words = start_words.replace(',',u',')\
											.replace('.',u'。')\
											.replace('?',u'?')
		gen_poetry = gen_acrostic if opt.acrostic else generate
		result = gen_poetry(model,start_words,ix2word,word2ix,prefix_words)
		print(''.join(result))
if __name__ == '__main__':
	import fire
	fire.Fire()

以上代码给我一些经验,

1. 了解python的编程方式,如空格、换行等;进一步了解python的各个基本模块;

2. 可能出的错误:函数名写错,大小写,变量名写错,括号不全。

3. 对cuda()的用法有了进一步认识;

4. 学会了调试程序(fire);

5. 学会了训练结果的可视化(visdom);

6. 进一步的了解了LSTM,对深度学习的架构、实现有了宏观把控。

这篇pytorch下使用LSTM神经网络写诗实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现解析Bit Torrent种子文件内容的方法
Aug 29 Python
django项目运行因中文而乱码报错的几种情况解决
Nov 07 Python
python排序函数sort()与sorted()的区别
Sep 18 Python
详解Python 正则表达式模块
Nov 05 Python
Python3内置模块random随机方法小结
Jul 13 Python
python Gunicorn服务器使用方法详解
Jul 22 Python
Pandas DataFrame中的tuple元素遍历的实现
Oct 23 Python
python自动脚本的pyautogui入门学习
Apr 01 Python
python 使用elasticsearch 实现翻页的三种方式
Jul 31 Python
python实现自动打卡的示例代码
Oct 10 Python
python打包生成so文件的实现
Oct 30 Python
tensorboard 可视化之localhost:6006不显示的解决方案
May 22 Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
python3.7通过thrift操作hbase的示例代码
Jan 14 #Python
解决pytorch DataLoader num_workers出现的问题
Jan 14 #Python
PyTorch实现ResNet50、ResNet101和ResNet152示例
Jan 14 #Python
python重要函数eval多种用法解析
Jan 14 #Python
You might like
CI框架安全类Security.php源码分析
2014/11/04 PHP
javascript+xml技术实现分页浏览
2008/07/27 Javascript
日期 时间js控件
2009/05/07 Javascript
基于jquery的让页面控件不可用的实现代码
2010/04/27 Javascript
jQuery帮助之筛选查找 children([expr])
2011/01/31 Javascript
js模拟点击以提交表单为例兼容主流浏览器
2013/11/29 Javascript
javascript定义变量时有var和没有var的区别探讨
2014/07/21 Javascript
JavaScript在浏览器标题栏上显示当前日期和时间的方法
2015/03/19 Javascript
JavaScript自定义等待wait函数实例分析
2015/03/23 Javascript
JavaScript中指定函数名称的相关方法
2015/06/04 Javascript
jQuery simpleModal插件的使用介绍
2016/08/30 Javascript
js中利用cookie实现记住密码功能
2020/08/20 Javascript
jQuery插件zTree实现的基本树与节点获取操作示例
2017/03/08 Javascript
Vue中render函数的使用方法
2018/01/31 Javascript
vue构建动态表单的方法示例
2018/09/22 Javascript
微信小程序登录按钮遮罩浮层效果的实现方法
2018/12/16 Javascript
vue实现全匹配搜索列表内容
2019/09/26 Javascript
微信小程序保持session会话的方法
2020/03/20 Javascript
[02:27]2018DOTA2亚洲邀请赛赛前采访-OpTic
2018/04/03 DOTA
[43:35]EG vs Winstrike 2018国际邀请赛小组赛BO2 第一场 8.18
2018/08/19 DOTA
Python数据结构之Array用法实例
2014/10/09 Python
python创建列表和向列表添加元素的实现方法
2017/12/25 Python
python的dataframe转换为多维矩阵的方法
2018/04/11 Python
H5仿微信界面教程(一)
2017/07/05 HTML / CSS
Dr. Martens马汀博士官网:马丁靴始祖品牌
2016/10/15 全球购物
UNIONBAY官网:美国青少年服装品牌
2019/03/26 全球购物
搬家公司的创业计划书
2014/01/01 职场文书
元旦晚会上单位领导演讲稿
2014/01/05 职场文书
大学生通用个人的自我评价
2014/02/10 职场文书
小学生开学第一课活动方案
2014/03/27 职场文书
授权委托书范本
2014/04/03 职场文书
维修工先进事迹
2014/05/29 职场文书
奶茶店创业计划书
2014/08/14 职场文书
2014年学习委员工作总结
2014/11/14 职场文书
酒店宣传语大全
2015/07/13 职场文书
JVM的类加载器和双亲委派模式你了解吗
2022/03/13 Java/Android