pytorch下使用LSTM神经网络写诗实例


Posted in Python onJanuary 14, 2020

在pytorch下,以数万首唐诗为素材,训练双层LSTM神经网络,使其能够以唐诗的方式写诗。

代码结构分为四部分,分别为

1.model.py,定义了双层LSTM模型

2.data.py,定义了从网上得到的唐诗数据的处理方法

3.utlis.py 定义了损失可视化的函数

4.main.py定义了模型参数,以及训练、唐诗生成函数。

参考:电子工业出版社的《深度学习框架PyTorch:入门与实践》第九章

main代码及注释如下

import sys, os
import torch as t
from data import get_data
from model import PoetryModel
from torch import nn
from torch.autograd import Variable
from utils import Visualizer
import tqdm
from torchnet import meter
import ipdb
 
class Config(object):
	data_path = 'data/'
	pickle_path = 'tang.npz'
	author = None
	constrain = None
	category = 'poet.tang' #or poet.song
	lr = 1e-3
	weight_decay = 1e-4
	use_gpu = True
	epoch = 20
	batch_size = 128
	maxlen = 125
	plot_every = 20
	#use_env = True #是否使用visodm
	env = 'poety' 
	#visdom env
	max_gen_len = 200
	debug_file = '/tmp/debugp'
	model_path = None
	prefix_words = '细雨鱼儿出,微风燕子斜。' 
	#不是诗歌组成部分,是意境
	start_words = '闲云潭影日悠悠' 
	#诗歌开始
	acrostic = False 
	#是否藏头
	model_prefix = 'checkpoints/tang' 
	#模型保存路径
opt = Config()
 
def generate(model, start_words, ix2word, word2ix, prefix_words=None):
	'''
	给定几个词,根据这几个词接着生成一首完整的诗歌
	'''
	results = list(start_words)
	start_word_len = len(start_words)
	# 手动设置第一个词为<START>
	# 这个地方有问题,最后需要再看一下
	input = Variable(t.Tensor([word2ix['<START>']]).view(1,1).long())
	if opt.use_gpu:input=input.cuda()
	hidden = None
	
	if prefix_words:
		for word in prefix_words:
			output,hidden = model(input,hidden)
			# 下边这句话是为了把input变成1*1?
			input = Variable(input.data.new([word2ix[word]])).view(1,1)
	for i in range(opt.max_gen_len):
		output,hidden = model(input,hidden)
		
		if i<start_word_len:
			w = results[i]
			input = Variable(input.data.new([word2ix[w]])).view(1,1)
		else:
			top_index = output.data[0].topk(1)[1][0]
			w = ix2word[top_index]
			results.append(w)
			input = Variable(input.data.new([top_index])).view(1,1)
		if w=='<EOP>':
			del results[-1] #-1的意思是倒数第一个
			break
	return results
 
def gen_acrostic(model,start_words,ix2word,word2ix, prefix_words = None):
 '''
 生成藏头诗
 start_words : u'深度学习'
 生成:
 深木通中岳,青苔半日脂。
 度山分地险,逆浪到南巴。
 学道兵犹毒,当时燕不移。
 习根通古岸,开镜出清羸。
 '''
 results = []
 start_word_len = len(start_words)
 input = Variable(t.Tensor([word2ix['<START>']]).view(1,1).long())
 if opt.use_gpu:input=input.cuda()
 hidden = None
 
 index=0 # 用来指示已经生成了多少句藏头诗
 # 上一个词
 pre_word='<START>'
 
 if prefix_words:
  for word in prefix_words:
   output,hidden = model(input,hidden)
   input = Variable(input.data.new([word2ix[word]])).view(1,1)
 
 for i in range(opt.max_gen_len):
  output,hidden = model(input,hidden)
  top_index = output.data[0].topk(1)[1][0]
  w = ix2word[top_index]
 
  if (pre_word in {u'。',u'!','<START>'} ):
   # 如果遇到句号,藏头的词送进去生成
 
   if index==start_word_len:
    # 如果生成的诗歌已经包含全部藏头的词,则结束
    break
   else: 
    # 把藏头的词作为输入送入模型
    w = start_words[index]
    index+=1
    input = Variable(input.data.new([word2ix[w]])).view(1,1) 
  else:
   # 否则的话,把上一次预测是词作为下一个词输入
   input = Variable(input.data.new([word2ix[w]])).view(1,1)
  results.append(w)
  pre_word = w
 return results
 
def train(**kwargs):
	
	for k,v in kwargs.items():
		setattr(opt,k,v) #设置apt里属性的值
	vis = Visualizer(env=opt.env)
	
	#获取数据
	data, word2ix, ix2word = get_data(opt) #get_data是data.py里的函数
	data = t.from_numpy(data)
	#这个地方出错了,是大写的L
	dataloader = t.utils.data.DataLoader(data, 
					batch_size = opt.batch_size,
					shuffle = True,
					num_workers = 1) #在python里,这样写程序可以吗?
 #模型定义
	model = PoetryModel(len(word2ix), 128, 256)
	optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)
	criterion = nn.CrossEntropyLoss()
 
	if opt.model_path:
		model.load_state_dict(t.load(opt.model_path))
	if opt.use_gpu:
		model.cuda()
		criterion.cuda()
		
	#The tnt.AverageValueMeter measures and returns the average value 
	#and the standard deviation of any collection of numbers that are 
	#added to it. It is useful, for instance, to measure the average 
	#loss over a collection of examples.
 
 #The add() function expects as input a Lua number value, which 
 #is the value that needs to be added to the list of values to 
 #average. It also takes as input an optional parameter n that 
 #assigns a weight to value in the average, in order to facilitate 
 #computing weighted averages (default = 1).
 
 #The tnt.AverageValueMeter has no parameters to be set at initialization time. 
	loss_meter = meter.AverageValueMeter()
	
	for epoch in range(opt.epoch):
		loss_meter.reset()
		for ii,data_ in tqdm.tqdm(enumerate(dataloader)):
			#tqdm是python中的进度条
			#训练
			data_ = data_.long().transpose(1,0).contiguous()
			#上边一句话,把data_变成long类型,把1维和0维转置,把内存调成连续的
			if opt.use_gpu: data_ = data_.cuda()
			optimizer.zero_grad()
			input_, target = Variable(data_[:-1,:]), Variable(data_[1:,:])
			#上边一句,将输入的诗句错开一个字,形成训练和目标
			output,_ = model(input_)
			loss = criterion(output, target.view(-1))
			loss.backward()
			optimizer.step()
			
			loss_meter.add(loss.data[0]) #为什么是data[0]?
			
			#可视化用到的是utlis.py里的函数
			if (1+ii)%opt.plot_every ==0:
				
				if os.path.exists(opt.debug_file):
					ipdb.set_trace()
				vis.plot('loss',loss_meter.value()[0])
				
				# 下面是对目前模型情况的测试,诗歌原文
				poetrys = [[ix2word[_word] for _word in data_[:,_iii]] 
									for _iii in range(data_.size(1))][:16]
				#上面句子嵌套了两个循环,主要是将诗歌索引的前十六个字变成原文
				vis.text('</br>'.join([''.join(poetry) for poetry in 
				poetrys]),win = u'origin_poem')
				gen_poetries = []
				#分别以以下几个字作为诗歌的第一个字,生成8首诗
				for word in list(u'春江花月夜凉如水'):
					gen_poetry = ''.join(generate(model,word,ix2word,word2ix))
					gen_poetries.append(gen_poetry)
				vis.text('</br>'.join([''.join(poetry) for poetry in 
				gen_poetries]), win = u'gen_poem')
		t.save(model.state_dict(), '%s_%s.pth' %(opt.model_prefix,epoch))
 
def gen(**kwargs):
	'''
	提供命令行接口,用以生成相应的诗
	'''
	
	for k,v in kwargs.items():
		setattr(opt,k,v)
	data, word2ix, ix2word = get_data(opt)
	model = PoetryModel(len(word2ix), 128, 256)
	map_location = lambda s,l:s
	# 上边句子里的map_location是在load里用的,用以加载到指定的CPU或GPU,
	# 上边句子的意思是将模型加载到默认的GPU上
	state_dict = t.load(opt.model_path, map_location = map_location)
	model.load_state_dict(state_dict)
	
	if opt.use_gpu:
		model.cuda()
	if sys.version_info.major == 3:
		if opt.start_words.insprintable():
			start_words = opt.start_words
			prefix_words = opt.prefix_words if opt.prefix_words else None
		else:
			start_words = opt.start_words.encode('ascii',\
			'surrogateescape').decode('utf8')
			prefix_words = opt.prefix_words.encode('ascii',\
			'surrogateescape').decode('utf8') if opt.prefix_words else None
		start_words = start_words.replace(',',u',')\
											.replace('.',u'。')\
											.replace('?',u'?')
		gen_poetry = gen_acrostic if opt.acrostic else generate
		result = gen_poetry(model,start_words,ix2word,word2ix,prefix_words)
		print(''.join(result))
if __name__ == '__main__':
	import fire
	fire.Fire()

以上代码给我一些经验,

1. 了解python的编程方式,如空格、换行等;进一步了解python的各个基本模块;

2. 可能出的错误:函数名写错,大小写,变量名写错,括号不全。

3. 对cuda()的用法有了进一步认识;

4. 学会了调试程序(fire);

5. 学会了训练结果的可视化(visdom);

6. 进一步的了解了LSTM,对深度学习的架构、实现有了宏观把控。

这篇pytorch下使用LSTM神经网络写诗实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中用format函数格式化字符串的用法
Apr 08 Python
python使用xlrd与xlwt对excel的读写和格式设定
Jan 21 Python
Python实现调度算法代码详解
Dec 01 Python
python中map的基本用法示例
Sep 10 Python
python随机在一张图像上截取任意大小图片的方法
Jan 24 Python
python openCV获取人脸部分并存储功能
Aug 28 Python
Python 实现日志同时输出到屏幕和文件
Feb 19 Python
Python matplotlib绘制图形实例(包括点,曲线,注释和箭头)
Apr 17 Python
Python多线程:主线程等待所有子线程结束代码
Apr 25 Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 Python
python脚本和网页有何区别
Jul 02 Python
python中pow函数用法及功能说明
Dec 04 Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
python3.7通过thrift操作hbase的示例代码
Jan 14 #Python
解决pytorch DataLoader num_workers出现的问题
Jan 14 #Python
PyTorch实现ResNet50、ResNet101和ResNet152示例
Jan 14 #Python
python重要函数eval多种用法解析
Jan 14 #Python
You might like
PHP-CGI进程CPU 100% 与 file_get_contents 函数的关系分析
2011/08/15 PHP
php模拟js函数unescape的函数代码
2012/10/20 PHP
PHP将进程作为守护进程的方法
2015/03/19 PHP
php实现从上传文件创建缩略图的方法
2015/04/02 PHP
PHP+MySql实现一个简单的留言板
2020/07/19 PHP
javascript知识点收藏
2007/02/22 Javascript
jquery tools 系列 scrollable(2)
2009/09/06 Javascript
Js 中debug方式
2010/02/07 Javascript
你需要知道的10个最佳javascript开发实践小结
2012/04/15 Javascript
JS实现的Select三级下拉菜单代码
2015/08/20 Javascript
js读取并解析JSON类型数据的方法
2015/11/14 Javascript
jQuery Easyui datagrid/treegrid 清空数据
2016/07/09 Javascript
浅谈对Angular中的生命周期钩子的理解
2017/07/31 Javascript
Vue仿支付宝支付功能
2018/05/25 Javascript
vue监听对象及对象属性问题
2018/08/20 Javascript
JavaScript原型对象原理与应用分析
2018/12/27 Javascript
详解Vue.js和layui日期控件冲突问题解决办法
2019/07/25 Javascript
最基础的Python的socket编程入门教程
2015/04/23 Python
python实现bucket排序算法实例分析
2015/05/04 Python
常见python正则用法的简单实例
2016/06/21 Python
详解windows python3.7安装numpy问题的解决方法
2018/08/13 Python
python自动分箱,计算woe,iv的实例代码
2019/11/22 Python
使用python实现CGI环境搭建过程解析
2020/04/28 Python
Python+Opencv身份证号码区域提取及识别实现
2020/08/25 Python
python3.7 openpyxl 在excel单元格中写入数据实例
2020/09/01 Python
美国真皮手袋品牌:GiGi New York
2017/03/10 全球购物
尼克松手表官网:Nixon手表
2019/03/17 全球购物
Quiksilver美国官网:始于1969年的优质冲浪服和滑雪板外套
2020/04/20 全球购物
Weblogic的布署方式
2013/08/23 面试题
高中同学聚会邀请函
2014/01/11 职场文书
汽车队司机先进事迹材料
2014/02/01 职场文书
中专生自我鉴定范文
2014/02/02 职场文书
领导接待方案
2014/03/13 职场文书
成绩单公证书
2014/04/10 职场文书
小学教师个人工作总结2015
2015/04/20 职场文书
2015年高中班主任工作总结
2015/04/30 职场文书