pytorch GAN生成对抗网络实例


Posted in Python onJanuary 10, 2020

我就废话不多说了,直接上代码吧!

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(1)
np.random.seed(1)

BATCH_SIZE = 64
LR_G = 0.0001
LR_D = 0.0001
N_IDEAS = 5
ART_COMPONENTS = 15
PAINT_POINTS = np.vstack([np.linspace(-1,1,ART_COMPONENTS) for _ in range(BATCH_SIZE)])

def artist_works():
	a = np.random.uniform(1,2,size=BATCH_SIZE)[:,np.newaxis]
	paintings = a*np.power(PAINT_POINTS,2) + (a-1)
	paintings = torch.from_numpy(paintings).float()
	return Variable(paintings)

G = nn.Sequential(
	nn.Linear(N_IDEAS,128),
	nn.ReLU(),
	nn.Linear(128,ART_COMPONENTS),
)

D = nn.Sequential(
	nn.Linear(ART_COMPONENTS,128),
	nn.ReLU(),
	nn.Linear(128,1),
	nn.Sigmoid(),
)

opt_D = torch.optim.Adam(D.parameters(),lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(),lr=LR_G)

plt.ion()

for step in range(10000):
	artist_paintings = artist_works()
	G_ideas = Variable(torch.randn(BATCH_SIZE,N_IDEAS))
	G_paintings = G(G_ideas)

	prob_artist0 = D(artist_paintings)
	prob_artist1 = D(G_paintings)

	D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1-prob_artist1))
	G_loss = torch.mean(torch.log(1 - prob_artist1))

	opt_D.zero_grad()
	D_loss.backward(retain_variables=True)
	opt_D.step()

	opt_G.zero_grad()
	G_loss.backward()
	opt_G.step()

	if step % 50 == 0:
		plt.cla()
		plt.plot(PAINT_POINTS[0],G_paintings.data.numpy()[0],c='#4ad631',lw=3,label='Generated painting',)
		plt.plot(PAINT_POINTS[0],2 * np.power(PAINT_POINTS[0], 2) + 1,c='#74BCFF',lw=3,label='upper bound',)
		plt.plot(PAINT_POINTS[0],1 * np.power(PAINT_POINTS[0], 2) + 0,c='#FF9359',lw=3,label='lower bound',)
		plt.text(-.5,2.3,'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size':15})
		plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 15})
		plt.ylim((0,3))
		plt.legend(loc='upper right', fontsize=12)
		plt.draw()
		plt.pause(0.01)

plt.ioff()
plt.show()

pytorch GAN生成对抗网络实例

以上这篇pytorch GAN生成对抗网络实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中type的构造函数参数含义说明
Jun 21 Python
详解Python中heapq模块的用法
Jun 28 Python
Python爬虫包 BeautifulSoup  递归抓取实例详解
Jan 28 Python
多版本Python共存的配置方法
May 22 Python
Python3.遍历某文件夹提取特定文件名的实例
Apr 26 Python
从DataFrame中提取出Series或DataFrame对象的方法
Nov 10 Python
python最小生成树kruskal与prim算法详解
Jan 17 Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 Python
python读取指定字节长度的文本方法
Aug 27 Python
Python 3.8 新功能大揭秘【新手必学】
Feb 05 Python
Pycharm激活码激活两种快速方式(附最新激活码和插件)
Mar 12 Python
python如何变换环境
Jul 21 Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
python列表推导和生成器表达式知识点总结
Jan 10 #Python
pytorch的梯度计算以及backward方法详解
Jan 10 #Python
Python如何获取Win7,Win10系统缩放大小
Jan 10 #Python
python-OpenCV 实现将数组转换成灰度图和彩图
Jan 09 #Python
You might like
php 论坛采集程序 模拟登陆,抓取页面 实现代码
2009/07/09 PHP
基于laravel制作APP接口(API)
2016/03/15 PHP
PHP使用反射机制实现查找类和方法的所在位置
2016/04/22 PHP
Jquery Ajax学习实例6 向WebService发出请求,返回DataSet(XML) 异步调用
2010/03/18 Javascript
xheditor与validate插件冲突的解决方案
2010/04/15 Javascript
javascript与CSS复习(二)
2010/06/29 Javascript
Extjs4 Treegrid 使用心得分享(经验篇)
2013/07/01 Javascript
详谈jQuery中的this和$(this)
2014/11/13 Javascript
nodejs 提示‘xxx’ 不是内部或外部命令解决方法
2014/11/20 NodeJs
jQuery异步上传文件插件ajaxFileUpload详细介绍
2015/05/19 Javascript
原生js模拟淘宝购物车项目实战
2015/11/18 Javascript
实例解析jQuery插件EasyUI最常用的表单验证规则
2015/11/29 Javascript
详解JavaScript表单验证(E-mail 验证)
2016/03/31 Javascript
详解React开发中使用require.ensure()按需加载ES6组件
2017/05/12 Javascript
jQuery简单绑定单个事件的方法示例
2017/06/10 jQuery
select标签设置默认选中的选项方法
2018/03/02 Javascript
在angular 6中使用 less 的实例代码
2018/05/13 Javascript
vue左右侧联动滚动的实现代码
2018/06/06 Javascript
react-router4按需加载(踩坑填坑)
2019/01/06 Javascript
JS实现的获取银行卡号归属地及银行卡类型操作示例
2019/01/08 Javascript
详解vue为什么要求组件模板只能有一个根元素
2019/07/22 Javascript
vue 自定指令生成uuid滚动监听达到tab表格吸顶效果的代码
2020/09/16 Javascript
python正常时间和unix时间戳相互转换的方法
2015/04/23 Python
使用Python保存网页上的图片或者保存页面为截图
2016/03/05 Python
Python机器学习之SVM支持向量机
2017/12/27 Python
在Python中实现函数重载的示例代码
2019/12/12 Python
python 工具 字符串转numpy浮点数组的实现
2020/03/14 Python
python小程序基于Jupyter实现天气查询的方法
2020/03/27 Python
Python中socket网络通信是干嘛的
2020/05/27 Python
python PIL模块的基本使用
2020/09/29 Python
使用PyCharm官方中文语言包汉化PyCharm
2020/11/18 Python
HTML5 微格式和相关的属性名称
2010/02/10 HTML / CSS
新闻学毕业生自荐信
2013/11/15 职场文书
合同专员岗位职责
2013/12/18 职场文书
篮球友谊赛通讯稿
2014/10/10 职场文书
统计工作个人总结
2015/03/03 职场文书