Keras自动下载的数据集/模型存放位置介绍


Posted in Python onJune 19, 2020

Mac

# 数据集
~/.keras/datasets/

# 模型
~/.keras/models/

Linux

# 数据集
~/.keras/datasets/

Windows

# win10
C:\Users\user_name\.keras\datasets

补充知识:Keras_gan生成自己的数据,并保存模型

我就废话不多说了,大家还是直接看代码吧~

from __future__ import print_function, division
 
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import os
import matplotlib.pyplot as plt
import sys
import numpy as np
 
class GAN():
 def __init__(self):
 self.img_rows = 3
 self.img_cols = 60
 self.channels = 1
 self.img_shape = (self.img_rows, self.img_cols, self.channels)
 self.latent_dim = 100
 
 optimizer = Adam(0.0002, 0.5)
 
 # 构建和编译判别器
 self.discriminator = self.build_discriminator()
 self.discriminator.compile(loss='binary_crossentropy',
  optimizer=optimizer,
  metrics=['accuracy'])
 
 # 构建生成器
 self.generator = self.build_generator()
 
 # 生成器输入噪音,生成假的图片
 z = Input(shape=(self.latent_dim,))
 img = self.generator(z)
 
 # 为了组合模型,只训练生成器
 self.discriminator.trainable = False
 
 # 判别器将生成的图像作为输入并确定有效性
 validity = self.discriminator(img)
 
 # The combined model (stacked generator and discriminator)
 # 训练生成器骗过判别器
 self.combined = Model(z, validity)
 self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
 
 def build_generator(self):
 
 model = Sequential()
 model.add(Dense(64, input_dim=self.latent_dim))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 model.add(Dense(128))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 model.add(Dense(256))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 model.add(Dense(512))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 model.add(Dense(1024))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 #np.prod(self.img_shape)=3x60x1
 model.add(Dense(np.prod(self.img_shape), activation='tanh'))
 model.add(Reshape(self.img_shape))
 
 model.summary()
 
 noise = Input(shape=(self.latent_dim,))
 img = model(noise)
 
 #输入噪音,输出图片
 return Model(noise, img)
 
 def build_discriminator(self):
 
 model = Sequential()
 
 model.add(Flatten(input_shape=self.img_shape))
 
 model.add(Dense(1024))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(512))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(256))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(128))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(64))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(1, activation='sigmoid'))
 model.summary()
 
 img = Input(shape=self.img_shape)
 validity = model(img)
 return Model(img, validity)
 
 def train(self, epochs, batch_size=128, sample_interval=50):
 
 ############################################################
 #自己数据集此部分需要更改
 # 加载数据集
 data = np.load('data/相对大小分叉.npy') 
 data = data[:,:,0:60]
 # 归一化到-1到1
 data = data * 2 - 1
 data = np.expand_dims(data, axis=3)
 ############################################################
 
 # Adversarial ground truths
 valid = np.ones((batch_size, 1))
 fake = np.zeros((batch_size, 1))
 
 for epoch in range(epochs):
 
  # ---------------------
  # 训练判别器
  # ---------------------
 
  # data.shape[0]为数据集的数量,随机生成batch_size个数量的随机数,作为数据的索引
  idx = np.random.randint(0, data.shape[0], batch_size)
  
  #从数据集随机挑选batch_size个数据,作为一个批次训练
  imgs = data[idx]
  
  #噪音维度(batch_size,100)
  noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
 
  # 由生成器根据噪音生成假的图片
  gen_imgs = self.generator.predict(noise)
 
  # 训练判别器,判别器希望真实图片,打上标签1,假的图片打上标签0
  d_loss_real = self.discriminator.train_on_batch(imgs, valid)
  d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
  d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
 
  # ---------------------
  # 训练生成器
  # ---------------------
 
  noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
 
  # Train the generator (to have the discriminator label samples as valid)
  g_loss = self.combined.train_on_batch(noise, valid)
 
  # 打印loss值
  print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
 
  # 没sample_interval个epoch保存一次生成图片
  if epoch % sample_interval == 0:
  self.sample_images(epoch)
  if not os.path.exists("keras_model"):
   os.makedirs("keras_model")
  self.generator.save_weights("keras_model/G_model%d.hdf5" % epoch,True)
  self.discriminator.save_weights("keras_model/D_model%d.hdf5" %epoch,True)
 
 def sample_images(self, epoch):
 r, c = 10, 10
 # 重新生成一批噪音,维度为(100,100)
 noise = np.random.normal(0, 1, (r * c, self.latent_dim))
 gen_imgs = self.generator.predict(noise)
 
 # 将生成的图片重新归整到0-1之间
 gen = 0.5 * gen_imgs + 0.5
 gen = gen.reshape(-1,3,60)
 
 fig,axs = plt.subplots(r,c) 
 cnt = 0 
 for i in range(r): 
  for j in range(c): 
  xy = gen[cnt] 
  for k in range(len(xy)): 
   x = xy[k][0:30] 
   y = xy[k][30:60] 
   if k == 0: 
   axs[i,j].plot(x,y,color='blue') 
   if k == 1: 
   axs[i,j].plot(x,y,color='red') 
   if k == 2: 
   axs[i,j].plot(x,y,color='green') 
   plt.xlim(0.,1.)
   plt.ylim(0.,1.)
   plt.xticks(np.arange(0,1,0.1))
   plt.xticks(np.arange(0,1,0.1))
   axs[i,j].axis('off')
  cnt += 1 
 if not os.path.exists("keras_imgs"):
  os.makedirs("keras_imgs")
 fig.savefig("keras_imgs/%d.png" % epoch)
 plt.close()
 
 def test(self,gen_nums=100,save=False):
 self.generator.load_weights("keras_model/G_model4000.hdf5",by_name=True)
 self.discriminator.load_weights("keras_model/D_model4000.hdf5",by_name=True)
 noise = np.random.normal(0,1,(gen_nums,self.latent_dim))
 gen = self.generator.predict(noise)
 gen = 0.5 * gen + 0.5
 gen = gen.reshape(-1,3,60)
 print(gen.shape)
 ###############################################################
 #直接可视化生成图片
 if save:
  for i in range(0,len(gen)):
  plt.figure(figsize=(128,128),dpi=1)
  plt.plot(gen[i][0][0:30],gen[i][0][30:60],color='blue',linewidth=300)
  plt.plot(gen[i][1][0:30],gen[i][1][30:60],color='red',linewidth=300)
  plt.plot(gen[i][2][0:30],gen[i][2][30:60],color='green',linewidth=300)
  plt.axis('off')
  plt.xlim(0.,1.)
  plt.ylim(0.,1.)
  plt.xticks(np.arange(0,1,0.1))
  plt.yticks(np.arange(0,1,0.1))
  if not os.path.exists("keras_gen"):
   os.makedirs("keras_gen")
  plt.savefig("keras_gen"+os.sep+str(i)+'.jpg',dpi=1)
  plt.close()
 ##################################################################
 #重整图片到0-1
 else:
  for i in range(len(gen)):
  plt.plot(gen[i][0][0:30],gen[i][0][30:60],color='blue')
  plt.plot(gen[i][1][0:30],gen[i][1][30:60],color='red')
  plt.plot(gen[i][2][0:30],gen[i][2][30:60],color='green')
  plt.xlim(0.,1.)
  plt.ylim(0.,1.)
  plt.xticks(np.arange(0,1,0.1))
  plt.xticks(np.arange(0,1,0.1))
  plt.show()
 
if __name__ == '__main__':
 gan = GAN()
 gan.train(epochs=300000, batch_size=32, sample_interval=2000)
# gan.test(save=True)

以上这篇Keras自动下载的数据集/模型存放位置介绍就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python翻译软件实现代码(使用google api完成)
Nov 26 Python
Python的条件语句与运算符优先级详解
Oct 13 Python
利用django-suit模板添加自定义的菜单、页面及设置访问权限
Jul 13 Python
Python 保存矩阵为Excel的实现方法
Jan 28 Python
树莓派采用socket方式文件传输(python)
Jun 22 Python
python tkinter 设置窗口大小不可缩放实例
Mar 04 Python
Keras预训练的ImageNet模型实现分类操作
Jul 07 Python
python 安装移动复制第三方库操作
Jul 13 Python
全面介绍python中很常用的单元测试框架unitest
Dec 14 Python
python 装饰器的基本使用
Jan 13 Python
Pycharm 解决自动格式化冲突的设置操作
Jan 15 Python
Python中tqdm的使用和例子
Sep 23 Python
Python应用实现处理excel数据过程解析
Jun 19 #Python
在tensorflow以及keras安装目录查询操作(windows下)
Jun 19 #Python
Scrapy框架介绍之Puppeteer渲染的使用
Jun 19 #Python
Python内置方法和属性应用:反射和单例(推荐)
Jun 19 #Python
Python应用实现双指数函数及拟合代码实例
Jun 19 #Python
PyQT5 实现快捷键复制表格数据的方法示例
Jun 19 #Python
如何在keras中添加自己的优化器(如adam等)
Jun 19 #Python
You might like
如何用phpmyadmin设置mysql数据库用户的权限
2012/01/09 PHP
php连接函数implode与分割explode的深入解析
2013/06/26 PHP
PHP通过GD库实现验证码功能示例
2019/02/23 PHP
js监听表单value的修改同步问题,跨浏览器支持
2009/12/31 Javascript
几个比较经典常用的jQuery小技巧
2010/03/01 Javascript
关于html+ashx开发中几个问题的解决方法
2011/07/18 Javascript
原始的js代码和jquery对比体会
2013/09/10 Javascript
JS实现倒计时和文字滚动的效果实例
2014/10/29 Javascript
AngularJS基础学习笔记之简单介绍
2015/05/10 Javascript
AngularJS中使用HTML5手机摄像头拍照
2016/02/22 Javascript
简单理解vue中el、template、replace元素
2016/10/27 Javascript
微信小程序 开发经验整理
2017/02/15 Javascript
Angular2平滑升级到Angular4的步骤详解
2017/03/29 Javascript
详解js正则表达式验证时间格式xxxx-xx-xx形式
2018/02/09 Javascript
Vue验证码60秒倒计时功能简单实例代码
2018/06/22 Javascript
基于VUE的v-charts的曲线显示功能
2019/10/01 Javascript
JavaScript异步操作的几种常见处理方法实例总结
2020/05/11 Javascript
VUE项目axios请求头更改Content-Type操作
2020/07/24 Javascript
Nuxt.js 静态资源和打包的操作
2020/11/06 Javascript
全网小程序接口请求封装实例代码
2020/11/06 Javascript
python ElementTree 基本读操作示例
2009/04/09 Python
Python常用小技巧总结
2015/06/01 Python
Python实现合并同一个文件夹下所有txt文件的方法示例
2018/04/26 Python
python调用webservice接口的实现
2019/07/12 Python
树莓派3 搭建 django 服务器的实例
2019/08/29 Python
树莓派4B安装Tensorflow的方法步骤
2020/07/16 Python
CSS3 实现弹跳的小球动画
2020/10/26 HTML / CSS
美国半成品食材配送服务商:Home Chef
2018/01/25 全球购物
印度排名第一的蛋糕、鲜花和礼品送货:Winni
2019/08/02 全球购物
2014国培学习感言
2014/03/05 职场文书
文案策划求职信
2014/03/18 职场文书
中药学专业毕业生推荐信
2014/07/10 职场文书
财务工作疏忽检讨书
2014/09/11 职场文书
酒店前台岗位职责
2015/04/16 职场文书
员工开除通知书
2015/04/25 职场文书
故意杀人罪辩护词
2015/05/21 职场文书