pytorch GAN生成对抗网络实例


Posted in Python onJanuary 10, 2020

我就废话不多说了,直接上代码吧!

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(1)
np.random.seed(1)

BATCH_SIZE = 64
LR_G = 0.0001
LR_D = 0.0001
N_IDEAS = 5
ART_COMPONENTS = 15
PAINT_POINTS = np.vstack([np.linspace(-1,1,ART_COMPONENTS) for _ in range(BATCH_SIZE)])

def artist_works():
	a = np.random.uniform(1,2,size=BATCH_SIZE)[:,np.newaxis]
	paintings = a*np.power(PAINT_POINTS,2) + (a-1)
	paintings = torch.from_numpy(paintings).float()
	return Variable(paintings)

G = nn.Sequential(
	nn.Linear(N_IDEAS,128),
	nn.ReLU(),
	nn.Linear(128,ART_COMPONENTS),
)

D = nn.Sequential(
	nn.Linear(ART_COMPONENTS,128),
	nn.ReLU(),
	nn.Linear(128,1),
	nn.Sigmoid(),
)

opt_D = torch.optim.Adam(D.parameters(),lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(),lr=LR_G)

plt.ion()

for step in range(10000):
	artist_paintings = artist_works()
	G_ideas = Variable(torch.randn(BATCH_SIZE,N_IDEAS))
	G_paintings = G(G_ideas)

	prob_artist0 = D(artist_paintings)
	prob_artist1 = D(G_paintings)

	D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1-prob_artist1))
	G_loss = torch.mean(torch.log(1 - prob_artist1))

	opt_D.zero_grad()
	D_loss.backward(retain_variables=True)
	opt_D.step()

	opt_G.zero_grad()
	G_loss.backward()
	opt_G.step()

	if step % 50 == 0:
		plt.cla()
		plt.plot(PAINT_POINTS[0],G_paintings.data.numpy()[0],c='#4ad631',lw=3,label='Generated painting',)
		plt.plot(PAINT_POINTS[0],2 * np.power(PAINT_POINTS[0], 2) + 1,c='#74BCFF',lw=3,label='upper bound',)
		plt.plot(PAINT_POINTS[0],1 * np.power(PAINT_POINTS[0], 2) + 0,c='#FF9359',lw=3,label='lower bound',)
		plt.text(-.5,2.3,'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size':15})
		plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 15})
		plt.ylim((0,3))
		plt.legend(loc='upper right', fontsize=12)
		plt.draw()
		plt.pause(0.01)

plt.ioff()
plt.show()

pytorch GAN生成对抗网络实例

以上这篇pytorch GAN生成对抗网络实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
零基础写python爬虫之打包生成exe文件
Nov 06 Python
Python中操作MySQL入门实例
Feb 08 Python
由Python运算π的值深入Python中科学计算的实现
Apr 17 Python
Python实现简单的代理服务器
Jul 25 Python
Python基础篇之初识Python必看攻略
Jun 23 Python
python在Windows下安装setuptools(easy_install工具)步骤详解
Jul 01 Python
python实现超简单的视频对象提取功能
Jun 04 Python
python使用thrift教程的方法示例
Mar 21 Python
python使用 zip 同时迭代多个序列示例
Jul 06 Python
python3实现elasticsearch批量更新数据
Dec 03 Python
jupyter使用自动补全和切换默认浏览器的方法
Nov 18 Python
Python+Tkinter制作专属图形化界面
Apr 01 Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
python列表推导和生成器表达式知识点总结
Jan 10 #Python
pytorch的梯度计算以及backward方法详解
Jan 10 #Python
Python如何获取Win7,Win10系统缩放大小
Jan 10 #Python
python-OpenCV 实现将数组转换成灰度图和彩图
Jan 09 #Python
You might like
浅析php插件 HTMLPurifier HTML解析器
2013/07/01 PHP
php fsockopen解决办法 php实现多线程
2014/01/20 PHP
详解PHP错误日志的获取方法
2015/07/20 PHP
PHP 模拟登陆功能实例详解
2019/09/10 PHP
js移除事件 js绑定事件实例应用
2012/11/28 Javascript
探讨js中的双感叹号判断
2013/11/11 Javascript
JS获取单击按钮单元格所在行的信息
2014/06/17 Javascript
深入理解JavaScript系列(31):设计模式之代理模式详解
2015/03/03 Javascript
Knockout结合Bootstrap创建动态UI实现产品列表管理
2016/09/14 Javascript
微信小程序-详解数据缓存
2016/11/24 Javascript
vue2.0开发实践总结之入门篇
2016/12/06 Javascript
smartupload实现文件上传时获取表单数据(推荐)
2016/12/12 Javascript
详解plotly.js 绘图库入门使用教程
2018/02/23 Javascript
npm 下载指定版本的组件方法
2018/05/17 Javascript
jQuery实现获取及设置CSS样式操作详解
2018/09/05 jQuery
教你如何用Node实现API的转发(某音乐)
2019/09/20 Javascript
浅析Vue 防抖与节流的使用
2019/11/14 Javascript
json_decode 索引为数字时自动排序问题解决方法
2020/03/28 Javascript
[26:52]LGD vs EG 2018国际邀请赛小组赛BO2 第一场 8.17
2018/08/18 DOTA
[41:52]DOTA2-DPC中国联赛 正赛 CDEC vs Dynasty BO3 第二场 2月22日
2021/03/11 DOTA
用ReactJS和Python的Flask框架编写留言板的代码示例
2015/12/19 Python
Python自动发邮件脚本
2017/03/31 Python
Python图片裁剪实例代码(如头像裁剪)
2017/06/21 Python
Python2 Selenium元素定位的实现(8种)
2019/02/25 Python
浅谈python中统计计数的几种方法和Counter详解
2019/11/07 Python
解决python replace函数替换无效问题
2020/01/18 Python
Numpy ndarray 多维数组对象的使用
2021/02/10 Python
英国羊绒服装购物网站:Pure Collection
2018/10/22 全球购物
求职简历推荐信范文
2013/12/02 职场文书
自我评价个人范文
2013/12/16 职场文书
大学同学聚会邀请函
2014/01/29 职场文书
《寓言两则》教学反思
2014/02/27 职场文书
学习型党组织建设经验材料
2014/05/26 职场文书
小学清明节活动总结
2014/07/04 职场文书
2016孝老爱亲模范事迹材料
2016/02/26 职场文书
TV动画《政宗君的复仇》第二季制作决定PV公布
2022/04/02 日漫