Keras自动下载的数据集/模型存放位置介绍


Posted in Python onJune 19, 2020

Mac

# 数据集
~/.keras/datasets/

# 模型
~/.keras/models/

Linux

# 数据集
~/.keras/datasets/

Windows

# win10
C:\Users\user_name\.keras\datasets

补充知识:Keras_gan生成自己的数据,并保存模型

我就废话不多说了,大家还是直接看代码吧~

from __future__ import print_function, division
 
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import os
import matplotlib.pyplot as plt
import sys
import numpy as np
 
class GAN():
 def __init__(self):
 self.img_rows = 3
 self.img_cols = 60
 self.channels = 1
 self.img_shape = (self.img_rows, self.img_cols, self.channels)
 self.latent_dim = 100
 
 optimizer = Adam(0.0002, 0.5)
 
 # 构建和编译判别器
 self.discriminator = self.build_discriminator()
 self.discriminator.compile(loss='binary_crossentropy',
  optimizer=optimizer,
  metrics=['accuracy'])
 
 # 构建生成器
 self.generator = self.build_generator()
 
 # 生成器输入噪音,生成假的图片
 z = Input(shape=(self.latent_dim,))
 img = self.generator(z)
 
 # 为了组合模型,只训练生成器
 self.discriminator.trainable = False
 
 # 判别器将生成的图像作为输入并确定有效性
 validity = self.discriminator(img)
 
 # The combined model (stacked generator and discriminator)
 # 训练生成器骗过判别器
 self.combined = Model(z, validity)
 self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
 
 def build_generator(self):
 
 model = Sequential()
 model.add(Dense(64, input_dim=self.latent_dim))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 model.add(Dense(128))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 model.add(Dense(256))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 model.add(Dense(512))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 model.add(Dense(1024))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 #np.prod(self.img_shape)=3x60x1
 model.add(Dense(np.prod(self.img_shape), activation='tanh'))
 model.add(Reshape(self.img_shape))
 
 model.summary()
 
 noise = Input(shape=(self.latent_dim,))
 img = model(noise)
 
 #输入噪音,输出图片
 return Model(noise, img)
 
 def build_discriminator(self):
 
 model = Sequential()
 
 model.add(Flatten(input_shape=self.img_shape))
 
 model.add(Dense(1024))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(512))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(256))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(128))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(64))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(1, activation='sigmoid'))
 model.summary()
 
 img = Input(shape=self.img_shape)
 validity = model(img)
 return Model(img, validity)
 
 def train(self, epochs, batch_size=128, sample_interval=50):
 
 ############################################################
 #自己数据集此部分需要更改
 # 加载数据集
 data = np.load('data/相对大小分叉.npy') 
 data = data[:,:,0:60]
 # 归一化到-1到1
 data = data * 2 - 1
 data = np.expand_dims(data, axis=3)
 ############################################################
 
 # Adversarial ground truths
 valid = np.ones((batch_size, 1))
 fake = np.zeros((batch_size, 1))
 
 for epoch in range(epochs):
 
  # ---------------------
  # 训练判别器
  # ---------------------
 
  # data.shape[0]为数据集的数量,随机生成batch_size个数量的随机数,作为数据的索引
  idx = np.random.randint(0, data.shape[0], batch_size)
  
  #从数据集随机挑选batch_size个数据,作为一个批次训练
  imgs = data[idx]
  
  #噪音维度(batch_size,100)
  noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
 
  # 由生成器根据噪音生成假的图片
  gen_imgs = self.generator.predict(noise)
 
  # 训练判别器,判别器希望真实图片,打上标签1,假的图片打上标签0
  d_loss_real = self.discriminator.train_on_batch(imgs, valid)
  d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
  d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
 
  # ---------------------
  # 训练生成器
  # ---------------------
 
  noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
 
  # Train the generator (to have the discriminator label samples as valid)
  g_loss = self.combined.train_on_batch(noise, valid)
 
  # 打印loss值
  print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
 
  # 没sample_interval个epoch保存一次生成图片
  if epoch % sample_interval == 0:
  self.sample_images(epoch)
  if not os.path.exists("keras_model"):
   os.makedirs("keras_model")
  self.generator.save_weights("keras_model/G_model%d.hdf5" % epoch,True)
  self.discriminator.save_weights("keras_model/D_model%d.hdf5" %epoch,True)
 
 def sample_images(self, epoch):
 r, c = 10, 10
 # 重新生成一批噪音,维度为(100,100)
 noise = np.random.normal(0, 1, (r * c, self.latent_dim))
 gen_imgs = self.generator.predict(noise)
 
 # 将生成的图片重新归整到0-1之间
 gen = 0.5 * gen_imgs + 0.5
 gen = gen.reshape(-1,3,60)
 
 fig,axs = plt.subplots(r,c) 
 cnt = 0 
 for i in range(r): 
  for j in range(c): 
  xy = gen[cnt] 
  for k in range(len(xy)): 
   x = xy[k][0:30] 
   y = xy[k][30:60] 
   if k == 0: 
   axs[i,j].plot(x,y,color='blue') 
   if k == 1: 
   axs[i,j].plot(x,y,color='red') 
   if k == 2: 
   axs[i,j].plot(x,y,color='green') 
   plt.xlim(0.,1.)
   plt.ylim(0.,1.)
   plt.xticks(np.arange(0,1,0.1))
   plt.xticks(np.arange(0,1,0.1))
   axs[i,j].axis('off')
  cnt += 1 
 if not os.path.exists("keras_imgs"):
  os.makedirs("keras_imgs")
 fig.savefig("keras_imgs/%d.png" % epoch)
 plt.close()
 
 def test(self,gen_nums=100,save=False):
 self.generator.load_weights("keras_model/G_model4000.hdf5",by_name=True)
 self.discriminator.load_weights("keras_model/D_model4000.hdf5",by_name=True)
 noise = np.random.normal(0,1,(gen_nums,self.latent_dim))
 gen = self.generator.predict(noise)
 gen = 0.5 * gen + 0.5
 gen = gen.reshape(-1,3,60)
 print(gen.shape)
 ###############################################################
 #直接可视化生成图片
 if save:
  for i in range(0,len(gen)):
  plt.figure(figsize=(128,128),dpi=1)
  plt.plot(gen[i][0][0:30],gen[i][0][30:60],color='blue',linewidth=300)
  plt.plot(gen[i][1][0:30],gen[i][1][30:60],color='red',linewidth=300)
  plt.plot(gen[i][2][0:30],gen[i][2][30:60],color='green',linewidth=300)
  plt.axis('off')
  plt.xlim(0.,1.)
  plt.ylim(0.,1.)
  plt.xticks(np.arange(0,1,0.1))
  plt.yticks(np.arange(0,1,0.1))
  if not os.path.exists("keras_gen"):
   os.makedirs("keras_gen")
  plt.savefig("keras_gen"+os.sep+str(i)+'.jpg',dpi=1)
  plt.close()
 ##################################################################
 #重整图片到0-1
 else:
  for i in range(len(gen)):
  plt.plot(gen[i][0][0:30],gen[i][0][30:60],color='blue')
  plt.plot(gen[i][1][0:30],gen[i][1][30:60],color='red')
  plt.plot(gen[i][2][0:30],gen[i][2][30:60],color='green')
  plt.xlim(0.,1.)
  plt.ylim(0.,1.)
  plt.xticks(np.arange(0,1,0.1))
  plt.xticks(np.arange(0,1,0.1))
  plt.show()
 
if __name__ == '__main__':
 gan = GAN()
 gan.train(epochs=300000, batch_size=32, sample_interval=2000)
# gan.test(save=True)

以上这篇Keras自动下载的数据集/模型存放位置介绍就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python将html转成PDF的实现代码(包含中文)
Mar 04 Python
Python自定义函数的创建、调用和函数的参数详解
Mar 11 Python
python通过定义一个类实例作为ftp回调方法
May 04 Python
python学习之编写查询ip程序
Feb 27 Python
python常用函数详解
Sep 13 Python
利用Python将数值型特征进行离散化操作的方法
Nov 06 Python
python 构造三维全零数组的方法
Nov 12 Python
tensor和numpy的互相转换的实现示例
Aug 02 Python
Python numpy线性代数用法实例解析
Nov 15 Python
解决matplotlib.pyplot在Jupyter notebook中不显示图像问题
Apr 22 Python
Python3实现建造者模式的示例代码
Jun 28 Python
解决PyCharm不在run输出运行结果而不是再Console里输出的问题
Sep 21 Python
Python应用实现处理excel数据过程解析
Jun 19 #Python
在tensorflow以及keras安装目录查询操作(windows下)
Jun 19 #Python
Scrapy框架介绍之Puppeteer渲染的使用
Jun 19 #Python
Python内置方法和属性应用:反射和单例(推荐)
Jun 19 #Python
Python应用实现双指数函数及拟合代码实例
Jun 19 #Python
PyQT5 实现快捷键复制表格数据的方法示例
Jun 19 #Python
如何在keras中添加自己的优化器(如adam等)
Jun 19 #Python
You might like
PHP 验证码的实现代码
2011/07/17 PHP
PHP解决URL中文GBK乱码问题的两种方法
2014/06/03 PHP
PHP操作MongoDB实现增删改查功能【附php7操作MongoDB方法】
2018/04/24 PHP
JavaScript语句可以不以;结尾的烦恼
2007/03/08 Javascript
复制本贴标题和地址的js代码
2008/07/01 Javascript
actionscript与javascript的区别
2011/05/25 Javascript
JavaScript中实现map功能代码分享
2015/06/11 Javascript
javascript实现计时器的简单方法
2016/02/21 Javascript
nodejs 终端打印进度条实例代码
2017/04/22 NodeJs
Angular2仿照微信UI实现9张图片上传和预览的示例代码
2017/10/19 Javascript
vue input 输入校验字母数字组合且长度小于30的实现代码
2018/05/16 Javascript
原生javascript实现连连看游戏
2019/01/03 Javascript
Vue+Element实现动态生成新表单并添加验证功能
2019/05/23 Javascript
js实现图片粘贴到网页
2019/12/06 Javascript
[02:28]DOTA2 2015国际邀请赛中国区预选赛首日现场百态
2015/05/26 DOTA
Python 可爱的大小写
2008/09/06 Python
Python Web框架Flask中使用七牛云存储实例
2015/02/08 Python
在Python中操作字典之clear()方法的使用
2015/05/21 Python
Python实现将DOC文档转换为PDF的方法
2015/07/25 Python
Python数据结构与算法之完全树与最小堆实例
2017/12/13 Python
python儿童学游戏编程知识点总结
2019/06/03 Python
Python FTP文件定时自动下载实现过程解析
2019/11/12 Python
python3读取csv文件任意行列代码实例
2020/01/13 Python
Python socket处理client连接过程解析
2020/03/18 Python
python实现交并比IOU教程
2020/04/16 Python
python使用Word2Vec进行情感分析解析
2020/07/31 Python
python爬虫筛选工作实例讲解
2020/11/23 Python
python 三种方法提取pdf中的图片
2021/02/07 Python
初中英语课后反思
2014/04/25 职场文书
师范生求职信
2014/06/14 职场文书
初中优秀教师事迹材料
2014/08/18 职场文书
好的促销活动方案
2014/08/21 职场文书
2015廉洁自律个人总结
2015/02/14 职场文书
升学宴来宾致辞
2015/07/27 职场文书
2016应届毕业生实习心得体会
2015/10/09 职场文书
css display table 自适应高度、宽度问题的解决
2021/05/07 HTML / CSS