Keras自动下载的数据集/模型存放位置介绍


Posted in Python onJune 19, 2020

Mac

# 数据集
~/.keras/datasets/

# 模型
~/.keras/models/

Linux

# 数据集
~/.keras/datasets/

Windows

# win10
C:\Users\user_name\.keras\datasets

补充知识:Keras_gan生成自己的数据,并保存模型

我就废话不多说了,大家还是直接看代码吧~

from __future__ import print_function, division
 
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import os
import matplotlib.pyplot as plt
import sys
import numpy as np
 
class GAN():
 def __init__(self):
 self.img_rows = 3
 self.img_cols = 60
 self.channels = 1
 self.img_shape = (self.img_rows, self.img_cols, self.channels)
 self.latent_dim = 100
 
 optimizer = Adam(0.0002, 0.5)
 
 # 构建和编译判别器
 self.discriminator = self.build_discriminator()
 self.discriminator.compile(loss='binary_crossentropy',
  optimizer=optimizer,
  metrics=['accuracy'])
 
 # 构建生成器
 self.generator = self.build_generator()
 
 # 生成器输入噪音,生成假的图片
 z = Input(shape=(self.latent_dim,))
 img = self.generator(z)
 
 # 为了组合模型,只训练生成器
 self.discriminator.trainable = False
 
 # 判别器将生成的图像作为输入并确定有效性
 validity = self.discriminator(img)
 
 # The combined model (stacked generator and discriminator)
 # 训练生成器骗过判别器
 self.combined = Model(z, validity)
 self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
 
 def build_generator(self):
 
 model = Sequential()
 model.add(Dense(64, input_dim=self.latent_dim))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 model.add(Dense(128))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 model.add(Dense(256))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 model.add(Dense(512))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 model.add(Dense(1024))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 #np.prod(self.img_shape)=3x60x1
 model.add(Dense(np.prod(self.img_shape), activation='tanh'))
 model.add(Reshape(self.img_shape))
 
 model.summary()
 
 noise = Input(shape=(self.latent_dim,))
 img = model(noise)
 
 #输入噪音,输出图片
 return Model(noise, img)
 
 def build_discriminator(self):
 
 model = Sequential()
 
 model.add(Flatten(input_shape=self.img_shape))
 
 model.add(Dense(1024))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(512))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(256))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(128))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(64))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(1, activation='sigmoid'))
 model.summary()
 
 img = Input(shape=self.img_shape)
 validity = model(img)
 return Model(img, validity)
 
 def train(self, epochs, batch_size=128, sample_interval=50):
 
 ############################################################
 #自己数据集此部分需要更改
 # 加载数据集
 data = np.load('data/相对大小分叉.npy') 
 data = data[:,:,0:60]
 # 归一化到-1到1
 data = data * 2 - 1
 data = np.expand_dims(data, axis=3)
 ############################################################
 
 # Adversarial ground truths
 valid = np.ones((batch_size, 1))
 fake = np.zeros((batch_size, 1))
 
 for epoch in range(epochs):
 
  # ---------------------
  # 训练判别器
  # ---------------------
 
  # data.shape[0]为数据集的数量,随机生成batch_size个数量的随机数,作为数据的索引
  idx = np.random.randint(0, data.shape[0], batch_size)
  
  #从数据集随机挑选batch_size个数据,作为一个批次训练
  imgs = data[idx]
  
  #噪音维度(batch_size,100)
  noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
 
  # 由生成器根据噪音生成假的图片
  gen_imgs = self.generator.predict(noise)
 
  # 训练判别器,判别器希望真实图片,打上标签1,假的图片打上标签0
  d_loss_real = self.discriminator.train_on_batch(imgs, valid)
  d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
  d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
 
  # ---------------------
  # 训练生成器
  # ---------------------
 
  noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
 
  # Train the generator (to have the discriminator label samples as valid)
  g_loss = self.combined.train_on_batch(noise, valid)
 
  # 打印loss值
  print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
 
  # 没sample_interval个epoch保存一次生成图片
  if epoch % sample_interval == 0:
  self.sample_images(epoch)
  if not os.path.exists("keras_model"):
   os.makedirs("keras_model")
  self.generator.save_weights("keras_model/G_model%d.hdf5" % epoch,True)
  self.discriminator.save_weights("keras_model/D_model%d.hdf5" %epoch,True)
 
 def sample_images(self, epoch):
 r, c = 10, 10
 # 重新生成一批噪音,维度为(100,100)
 noise = np.random.normal(0, 1, (r * c, self.latent_dim))
 gen_imgs = self.generator.predict(noise)
 
 # 将生成的图片重新归整到0-1之间
 gen = 0.5 * gen_imgs + 0.5
 gen = gen.reshape(-1,3,60)
 
 fig,axs = plt.subplots(r,c) 
 cnt = 0 
 for i in range(r): 
  for j in range(c): 
  xy = gen[cnt] 
  for k in range(len(xy)): 
   x = xy[k][0:30] 
   y = xy[k][30:60] 
   if k == 0: 
   axs[i,j].plot(x,y,color='blue') 
   if k == 1: 
   axs[i,j].plot(x,y,color='red') 
   if k == 2: 
   axs[i,j].plot(x,y,color='green') 
   plt.xlim(0.,1.)
   plt.ylim(0.,1.)
   plt.xticks(np.arange(0,1,0.1))
   plt.xticks(np.arange(0,1,0.1))
   axs[i,j].axis('off')
  cnt += 1 
 if not os.path.exists("keras_imgs"):
  os.makedirs("keras_imgs")
 fig.savefig("keras_imgs/%d.png" % epoch)
 plt.close()
 
 def test(self,gen_nums=100,save=False):
 self.generator.load_weights("keras_model/G_model4000.hdf5",by_name=True)
 self.discriminator.load_weights("keras_model/D_model4000.hdf5",by_name=True)
 noise = np.random.normal(0,1,(gen_nums,self.latent_dim))
 gen = self.generator.predict(noise)
 gen = 0.5 * gen + 0.5
 gen = gen.reshape(-1,3,60)
 print(gen.shape)
 ###############################################################
 #直接可视化生成图片
 if save:
  for i in range(0,len(gen)):
  plt.figure(figsize=(128,128),dpi=1)
  plt.plot(gen[i][0][0:30],gen[i][0][30:60],color='blue',linewidth=300)
  plt.plot(gen[i][1][0:30],gen[i][1][30:60],color='red',linewidth=300)
  plt.plot(gen[i][2][0:30],gen[i][2][30:60],color='green',linewidth=300)
  plt.axis('off')
  plt.xlim(0.,1.)
  plt.ylim(0.,1.)
  plt.xticks(np.arange(0,1,0.1))
  plt.yticks(np.arange(0,1,0.1))
  if not os.path.exists("keras_gen"):
   os.makedirs("keras_gen")
  plt.savefig("keras_gen"+os.sep+str(i)+'.jpg',dpi=1)
  plt.close()
 ##################################################################
 #重整图片到0-1
 else:
  for i in range(len(gen)):
  plt.plot(gen[i][0][0:30],gen[i][0][30:60],color='blue')
  plt.plot(gen[i][1][0:30],gen[i][1][30:60],color='red')
  plt.plot(gen[i][2][0:30],gen[i][2][30:60],color='green')
  plt.xlim(0.,1.)
  plt.ylim(0.,1.)
  plt.xticks(np.arange(0,1,0.1))
  plt.xticks(np.arange(0,1,0.1))
  plt.show()
 
if __name__ == '__main__':
 gan = GAN()
 gan.train(epochs=300000, batch_size=32, sample_interval=2000)
# gan.test(save=True)

以上这篇Keras自动下载的数据集/模型存放位置介绍就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现自动重启本程序的方法
Jul 09 Python
Python 实现字符串中指定位置插入一个字符
May 02 Python
浅谈python 导入模块和解决文件句柄找不到问题
Dec 15 Python
Centos部署django服务nginx+uwsgi的方法
Jan 02 Python
解决Django生产环境无法加载静态文件问题的解决
Apr 23 Python
numpy.where() 用法详解
May 27 Python
FFrpc python客户端lib使用解析
Aug 24 Python
wxPython电子表格功能wx.grid实例教程
Nov 19 Python
python + selenium 刷B站播放量的实例代码
Jun 12 Python
python属于解释型语言么
Jun 15 Python
基于CentOS搭建Python Django环境过程解析
Aug 24 Python
python中翻译功能translate模块实现方法
Dec 17 Python
Python应用实现处理excel数据过程解析
Jun 19 #Python
在tensorflow以及keras安装目录查询操作(windows下)
Jun 19 #Python
Scrapy框架介绍之Puppeteer渲染的使用
Jun 19 #Python
Python内置方法和属性应用:反射和单例(推荐)
Jun 19 #Python
Python应用实现双指数函数及拟合代码实例
Jun 19 #Python
PyQT5 实现快捷键复制表格数据的方法示例
Jun 19 #Python
如何在keras中添加自己的优化器(如adam等)
Jun 19 #Python
You might like
PHP伪静态页面函数附使用方法
2008/06/20 PHP
解析PayPal支付接口的PHP开发方式
2010/11/28 PHP
PHP请求Socket接口测试实例
2016/08/12 PHP
IE下使用cloneNode注意事项分享
2012/11/22 Javascript
JavaScript prototype属性深入介绍
2012/11/27 Javascript
Javascript浮点数乘积运算出现多位小数的解决方法
2014/02/17 Javascript
JS数组的常见用法实例
2015/02/10 Javascript
JS使用cookie实现DIV提示框只显示一次的方法
2015/11/05 Javascript
javascript下拉列表菜单的实现方法
2015/11/18 Javascript
JS中将多个逗号替换为一个逗号的实现代码
2017/06/23 Javascript
ES6中Array.copyWithin()函数的用法实例详解
2017/09/16 Javascript
Vue + better-scroll 实现移动端字母索引导航功能
2018/05/07 Javascript
vue-cli 默认路由再子路由选中下的选中状态问题及解决代码
2018/09/06 Javascript
vue2 设置router-view默认路径的实例
2018/09/20 Javascript
vue实现div拖拽互换位置
2020/07/29 Javascript
Vue路由守卫之路由独享守卫
2019/09/25 Javascript
这15个Vue指令,让你的项目开发爽到爆
2019/10/11 Javascript
Vue 实现CLI 3.0 + momentjs + lodash打包时优化
2019/11/13 Javascript
微信小程序实现上拉加载功能
2019/11/20 Javascript
javascript实现倒计时提示框
2021/03/02 Javascript
Python文件操作类操作实例详解
2014/07/11 Python
Python学习笔记整理3之输入输出、python eval函数
2015/12/14 Python
django中related_name的用法说明
2020/05/20 Python
创业计划书——互联网商机
2014/01/12 职场文书
新娘父亲婚礼致辞
2014/01/16 职场文书
《巨人的花园》教学反思
2014/02/12 职场文书
知识竞赛活动方案
2014/02/18 职场文书
祖国在我心中演讲稿200字
2014/08/28 职场文书
食品安全承诺书范文
2014/08/29 职场文书
2015年南京大屠杀纪念日活动总结
2015/03/24 职场文书
奥巴马开学演讲观后感
2015/06/12 职场文书
商务宴会祝酒词
2015/08/11 职场文书
学校趣味运动会开幕词
2016/03/04 职场文书
Python离线安装openpyxl模块的步骤
2021/03/30 Python
Python Pandas 删除列操作
2022/03/16 Python
画错魏国疆域啦!《派对咖孔明》动画因作画失误于官网致歉
2022/04/07 日漫