Keras自动下载的数据集/模型存放位置介绍


Posted in Python onJune 19, 2020

Mac

# 数据集
~/.keras/datasets/

# 模型
~/.keras/models/

Linux

# 数据集
~/.keras/datasets/

Windows

# win10
C:\Users\user_name\.keras\datasets

补充知识:Keras_gan生成自己的数据,并保存模型

我就废话不多说了,大家还是直接看代码吧~

from __future__ import print_function, division
 
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import os
import matplotlib.pyplot as plt
import sys
import numpy as np
 
class GAN():
 def __init__(self):
 self.img_rows = 3
 self.img_cols = 60
 self.channels = 1
 self.img_shape = (self.img_rows, self.img_cols, self.channels)
 self.latent_dim = 100
 
 optimizer = Adam(0.0002, 0.5)
 
 # 构建和编译判别器
 self.discriminator = self.build_discriminator()
 self.discriminator.compile(loss='binary_crossentropy',
  optimizer=optimizer,
  metrics=['accuracy'])
 
 # 构建生成器
 self.generator = self.build_generator()
 
 # 生成器输入噪音,生成假的图片
 z = Input(shape=(self.latent_dim,))
 img = self.generator(z)
 
 # 为了组合模型,只训练生成器
 self.discriminator.trainable = False
 
 # 判别器将生成的图像作为输入并确定有效性
 validity = self.discriminator(img)
 
 # The combined model (stacked generator and discriminator)
 # 训练生成器骗过判别器
 self.combined = Model(z, validity)
 self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
 
 def build_generator(self):
 
 model = Sequential()
 model.add(Dense(64, input_dim=self.latent_dim))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 model.add(Dense(128))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 model.add(Dense(256))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 model.add(Dense(512))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 model.add(Dense(1024))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 #np.prod(self.img_shape)=3x60x1
 model.add(Dense(np.prod(self.img_shape), activation='tanh'))
 model.add(Reshape(self.img_shape))
 
 model.summary()
 
 noise = Input(shape=(self.latent_dim,))
 img = model(noise)
 
 #输入噪音,输出图片
 return Model(noise, img)
 
 def build_discriminator(self):
 
 model = Sequential()
 
 model.add(Flatten(input_shape=self.img_shape))
 
 model.add(Dense(1024))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(512))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(256))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(128))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(64))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(1, activation='sigmoid'))
 model.summary()
 
 img = Input(shape=self.img_shape)
 validity = model(img)
 return Model(img, validity)
 
 def train(self, epochs, batch_size=128, sample_interval=50):
 
 ############################################################
 #自己数据集此部分需要更改
 # 加载数据集
 data = np.load('data/相对大小分叉.npy') 
 data = data[:,:,0:60]
 # 归一化到-1到1
 data = data * 2 - 1
 data = np.expand_dims(data, axis=3)
 ############################################################
 
 # Adversarial ground truths
 valid = np.ones((batch_size, 1))
 fake = np.zeros((batch_size, 1))
 
 for epoch in range(epochs):
 
  # ---------------------
  # 训练判别器
  # ---------------------
 
  # data.shape[0]为数据集的数量,随机生成batch_size个数量的随机数,作为数据的索引
  idx = np.random.randint(0, data.shape[0], batch_size)
  
  #从数据集随机挑选batch_size个数据,作为一个批次训练
  imgs = data[idx]
  
  #噪音维度(batch_size,100)
  noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
 
  # 由生成器根据噪音生成假的图片
  gen_imgs = self.generator.predict(noise)
 
  # 训练判别器,判别器希望真实图片,打上标签1,假的图片打上标签0
  d_loss_real = self.discriminator.train_on_batch(imgs, valid)
  d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
  d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
 
  # ---------------------
  # 训练生成器
  # ---------------------
 
  noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
 
  # Train the generator (to have the discriminator label samples as valid)
  g_loss = self.combined.train_on_batch(noise, valid)
 
  # 打印loss值
  print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
 
  # 没sample_interval个epoch保存一次生成图片
  if epoch % sample_interval == 0:
  self.sample_images(epoch)
  if not os.path.exists("keras_model"):
   os.makedirs("keras_model")
  self.generator.save_weights("keras_model/G_model%d.hdf5" % epoch,True)
  self.discriminator.save_weights("keras_model/D_model%d.hdf5" %epoch,True)
 
 def sample_images(self, epoch):
 r, c = 10, 10
 # 重新生成一批噪音,维度为(100,100)
 noise = np.random.normal(0, 1, (r * c, self.latent_dim))
 gen_imgs = self.generator.predict(noise)
 
 # 将生成的图片重新归整到0-1之间
 gen = 0.5 * gen_imgs + 0.5
 gen = gen.reshape(-1,3,60)
 
 fig,axs = plt.subplots(r,c) 
 cnt = 0 
 for i in range(r): 
  for j in range(c): 
  xy = gen[cnt] 
  for k in range(len(xy)): 
   x = xy[k][0:30] 
   y = xy[k][30:60] 
   if k == 0: 
   axs[i,j].plot(x,y,color='blue') 
   if k == 1: 
   axs[i,j].plot(x,y,color='red') 
   if k == 2: 
   axs[i,j].plot(x,y,color='green') 
   plt.xlim(0.,1.)
   plt.ylim(0.,1.)
   plt.xticks(np.arange(0,1,0.1))
   plt.xticks(np.arange(0,1,0.1))
   axs[i,j].axis('off')
  cnt += 1 
 if not os.path.exists("keras_imgs"):
  os.makedirs("keras_imgs")
 fig.savefig("keras_imgs/%d.png" % epoch)
 plt.close()
 
 def test(self,gen_nums=100,save=False):
 self.generator.load_weights("keras_model/G_model4000.hdf5",by_name=True)
 self.discriminator.load_weights("keras_model/D_model4000.hdf5",by_name=True)
 noise = np.random.normal(0,1,(gen_nums,self.latent_dim))
 gen = self.generator.predict(noise)
 gen = 0.5 * gen + 0.5
 gen = gen.reshape(-1,3,60)
 print(gen.shape)
 ###############################################################
 #直接可视化生成图片
 if save:
  for i in range(0,len(gen)):
  plt.figure(figsize=(128,128),dpi=1)
  plt.plot(gen[i][0][0:30],gen[i][0][30:60],color='blue',linewidth=300)
  plt.plot(gen[i][1][0:30],gen[i][1][30:60],color='red',linewidth=300)
  plt.plot(gen[i][2][0:30],gen[i][2][30:60],color='green',linewidth=300)
  plt.axis('off')
  plt.xlim(0.,1.)
  plt.ylim(0.,1.)
  plt.xticks(np.arange(0,1,0.1))
  plt.yticks(np.arange(0,1,0.1))
  if not os.path.exists("keras_gen"):
   os.makedirs("keras_gen")
  plt.savefig("keras_gen"+os.sep+str(i)+'.jpg',dpi=1)
  plt.close()
 ##################################################################
 #重整图片到0-1
 else:
  for i in range(len(gen)):
  plt.plot(gen[i][0][0:30],gen[i][0][30:60],color='blue')
  plt.plot(gen[i][1][0:30],gen[i][1][30:60],color='red')
  plt.plot(gen[i][2][0:30],gen[i][2][30:60],color='green')
  plt.xlim(0.,1.)
  plt.ylim(0.,1.)
  plt.xticks(np.arange(0,1,0.1))
  plt.xticks(np.arange(0,1,0.1))
  plt.show()
 
if __name__ == '__main__':
 gan = GAN()
 gan.train(epochs=300000, batch_size=32, sample_interval=2000)
# gan.test(save=True)

以上这篇Keras自动下载的数据集/模型存放位置介绍就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python编程实现的简单神经网络算法示例
Jan 26 Python
Python 错误和异常代码详解
Jan 29 Python
Django中redis的使用方法(包括安装、配置、启动)
Feb 21 Python
解决pip install的时候报错timed out的问题
Jun 12 Python
Python GUI编程 文本弹窗的实例
Jun 11 Python
详解将Python程序(.py)转换为Windows可执行文件(.exe)
Jul 19 Python
Python和Anaconda和Pycharm安装教程图文详解
Feb 04 Python
Pycharm中安装Pygal并使用Pygal模拟掷骰子(推荐)
Apr 08 Python
keras Lambda自定义层实现数据的切片方式,Lambda传参数
Jun 11 Python
python关于倒排列的知识点总结
Oct 13 Python
python flask开发的简单基金查询工具
Jun 02 Python
Python用any()函数检查字符串中的字母以及如何使用all()函数
Apr 14 Python
Python应用实现处理excel数据过程解析
Jun 19 #Python
在tensorflow以及keras安装目录查询操作(windows下)
Jun 19 #Python
Scrapy框架介绍之Puppeteer渲染的使用
Jun 19 #Python
Python内置方法和属性应用:反射和单例(推荐)
Jun 19 #Python
Python应用实现双指数函数及拟合代码实例
Jun 19 #Python
PyQT5 实现快捷键复制表格数据的方法示例
Jun 19 #Python
如何在keras中添加自己的优化器(如adam等)
Jun 19 #Python
You might like
安健A254立体声随身听的分析与打磨
2021/03/02 无线电
基于mysql的bbs设计(五)
2006/10/09 PHP
ThinkPHP查询中的魔术方法简述
2014/06/25 PHP
对于ThinkPHP框架早期版本的一个SQL注入漏洞详细分析
2014/07/04 PHP
php header函数的常用http头设置
2015/06/25 PHP
利用 fsockopen() 函数开放端口扫描器的实例
2017/08/19 PHP
PHP开发中解决并发问题的几种实现方法分析
2017/11/13 PHP
laravel withCount 统计关联数量的方法
2019/10/10 PHP
js 连接数据库如何操作数据库中的数据
2012/11/23 Javascript
原生JavaScript实现连连看游戏(附源码)
2013/11/05 Javascript
js跑步算法的实现代码
2013/12/04 Javascript
javascript获取web应用根目录的方法
2014/02/12 Javascript
js对字符串进行编码的方法总结(推荐)
2016/11/10 Javascript
JS 终止执行的实现方法
2016/11/24 Javascript
浅析BootStrap中Modal(模态框)使用心得
2016/12/24 Javascript
jQuery Ajax前后端使用JSON进行交互示例
2017/03/17 Javascript
JS ES6中setTimeout函数的执行上下文示例
2017/04/27 Javascript
react native带索引的城市列表组件的实例代码
2017/08/08 Javascript
JavaScript框架Angular和React深度对比
2017/11/20 Javascript
Python自定义函数的创建、调用和函数的参数详解
2014/03/11 Python
用Python编写简单的定时器的方法
2015/05/02 Python
Python基于checksum计算文件是否相同的方法
2015/07/09 Python
Python简单生成8位随机密码的方法
2017/05/24 Python
使用python opencv对目录下图片进行去重的方法
2019/01/12 Python
Python常用爬虫代码总结方便查询
2019/02/25 Python
Python3内置模块pprint让打印比print更美观详解
2019/06/02 Python
python IDLE添加行号显示教程
2020/04/25 Python
mac安装python3后使用pip和pip3的区别说明
2020/09/01 Python
浅谈pc和移动端的响应式的使用
2019/01/03 HTML / CSS
HTML5 Blob 实现文件下载功能的示例代码
2019/11/29 HTML / CSS
GIVENCHY纪梵希官方旗舰店:高定彩妆与贵族护肤品
2018/04/16 全球购物
初级党校心得体会
2014/09/11 职场文书
职工擅自离岗检讨书
2014/09/23 职场文书
2015元旦晚会主持词(开场白+结束语)
2014/12/14 职场文书
安全事故隐患排查治理制度
2015/08/05 职场文书
2016个人廉洁自律承诺书
2016/03/25 职场文书