Keras自动下载的数据集/模型存放位置介绍


Posted in Python onJune 19, 2020

Mac

# 数据集
~/.keras/datasets/

# 模型
~/.keras/models/

Linux

# 数据集
~/.keras/datasets/

Windows

# win10
C:\Users\user_name\.keras\datasets

补充知识:Keras_gan生成自己的数据,并保存模型

我就废话不多说了,大家还是直接看代码吧~

from __future__ import print_function, division
 
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import os
import matplotlib.pyplot as plt
import sys
import numpy as np
 
class GAN():
 def __init__(self):
 self.img_rows = 3
 self.img_cols = 60
 self.channels = 1
 self.img_shape = (self.img_rows, self.img_cols, self.channels)
 self.latent_dim = 100
 
 optimizer = Adam(0.0002, 0.5)
 
 # 构建和编译判别器
 self.discriminator = self.build_discriminator()
 self.discriminator.compile(loss='binary_crossentropy',
  optimizer=optimizer,
  metrics=['accuracy'])
 
 # 构建生成器
 self.generator = self.build_generator()
 
 # 生成器输入噪音,生成假的图片
 z = Input(shape=(self.latent_dim,))
 img = self.generator(z)
 
 # 为了组合模型,只训练生成器
 self.discriminator.trainable = False
 
 # 判别器将生成的图像作为输入并确定有效性
 validity = self.discriminator(img)
 
 # The combined model (stacked generator and discriminator)
 # 训练生成器骗过判别器
 self.combined = Model(z, validity)
 self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
 
 def build_generator(self):
 
 model = Sequential()
 model.add(Dense(64, input_dim=self.latent_dim))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 model.add(Dense(128))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 model.add(Dense(256))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 model.add(Dense(512))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 model.add(Dense(1024))
 model.add(LeakyReLU(alpha=0.2))
 model.add(BatchNormalization(momentum=0.8))
 
 #np.prod(self.img_shape)=3x60x1
 model.add(Dense(np.prod(self.img_shape), activation='tanh'))
 model.add(Reshape(self.img_shape))
 
 model.summary()
 
 noise = Input(shape=(self.latent_dim,))
 img = model(noise)
 
 #输入噪音,输出图片
 return Model(noise, img)
 
 def build_discriminator(self):
 
 model = Sequential()
 
 model.add(Flatten(input_shape=self.img_shape))
 
 model.add(Dense(1024))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(512))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(256))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(128))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(64))
 model.add(LeakyReLU(alpha=0.2))
 
 model.add(Dense(1, activation='sigmoid'))
 model.summary()
 
 img = Input(shape=self.img_shape)
 validity = model(img)
 return Model(img, validity)
 
 def train(self, epochs, batch_size=128, sample_interval=50):
 
 ############################################################
 #自己数据集此部分需要更改
 # 加载数据集
 data = np.load('data/相对大小分叉.npy') 
 data = data[:,:,0:60]
 # 归一化到-1到1
 data = data * 2 - 1
 data = np.expand_dims(data, axis=3)
 ############################################################
 
 # Adversarial ground truths
 valid = np.ones((batch_size, 1))
 fake = np.zeros((batch_size, 1))
 
 for epoch in range(epochs):
 
  # ---------------------
  # 训练判别器
  # ---------------------
 
  # data.shape[0]为数据集的数量,随机生成batch_size个数量的随机数,作为数据的索引
  idx = np.random.randint(0, data.shape[0], batch_size)
  
  #从数据集随机挑选batch_size个数据,作为一个批次训练
  imgs = data[idx]
  
  #噪音维度(batch_size,100)
  noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
 
  # 由生成器根据噪音生成假的图片
  gen_imgs = self.generator.predict(noise)
 
  # 训练判别器,判别器希望真实图片,打上标签1,假的图片打上标签0
  d_loss_real = self.discriminator.train_on_batch(imgs, valid)
  d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
  d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
 
  # ---------------------
  # 训练生成器
  # ---------------------
 
  noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
 
  # Train the generator (to have the discriminator label samples as valid)
  g_loss = self.combined.train_on_batch(noise, valid)
 
  # 打印loss值
  print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
 
  # 没sample_interval个epoch保存一次生成图片
  if epoch % sample_interval == 0:
  self.sample_images(epoch)
  if not os.path.exists("keras_model"):
   os.makedirs("keras_model")
  self.generator.save_weights("keras_model/G_model%d.hdf5" % epoch,True)
  self.discriminator.save_weights("keras_model/D_model%d.hdf5" %epoch,True)
 
 def sample_images(self, epoch):
 r, c = 10, 10
 # 重新生成一批噪音,维度为(100,100)
 noise = np.random.normal(0, 1, (r * c, self.latent_dim))
 gen_imgs = self.generator.predict(noise)
 
 # 将生成的图片重新归整到0-1之间
 gen = 0.5 * gen_imgs + 0.5
 gen = gen.reshape(-1,3,60)
 
 fig,axs = plt.subplots(r,c) 
 cnt = 0 
 for i in range(r): 
  for j in range(c): 
  xy = gen[cnt] 
  for k in range(len(xy)): 
   x = xy[k][0:30] 
   y = xy[k][30:60] 
   if k == 0: 
   axs[i,j].plot(x,y,color='blue') 
   if k == 1: 
   axs[i,j].plot(x,y,color='red') 
   if k == 2: 
   axs[i,j].plot(x,y,color='green') 
   plt.xlim(0.,1.)
   plt.ylim(0.,1.)
   plt.xticks(np.arange(0,1,0.1))
   plt.xticks(np.arange(0,1,0.1))
   axs[i,j].axis('off')
  cnt += 1 
 if not os.path.exists("keras_imgs"):
  os.makedirs("keras_imgs")
 fig.savefig("keras_imgs/%d.png" % epoch)
 plt.close()
 
 def test(self,gen_nums=100,save=False):
 self.generator.load_weights("keras_model/G_model4000.hdf5",by_name=True)
 self.discriminator.load_weights("keras_model/D_model4000.hdf5",by_name=True)
 noise = np.random.normal(0,1,(gen_nums,self.latent_dim))
 gen = self.generator.predict(noise)
 gen = 0.5 * gen + 0.5
 gen = gen.reshape(-1,3,60)
 print(gen.shape)
 ###############################################################
 #直接可视化生成图片
 if save:
  for i in range(0,len(gen)):
  plt.figure(figsize=(128,128),dpi=1)
  plt.plot(gen[i][0][0:30],gen[i][0][30:60],color='blue',linewidth=300)
  plt.plot(gen[i][1][0:30],gen[i][1][30:60],color='red',linewidth=300)
  plt.plot(gen[i][2][0:30],gen[i][2][30:60],color='green',linewidth=300)
  plt.axis('off')
  plt.xlim(0.,1.)
  plt.ylim(0.,1.)
  plt.xticks(np.arange(0,1,0.1))
  plt.yticks(np.arange(0,1,0.1))
  if not os.path.exists("keras_gen"):
   os.makedirs("keras_gen")
  plt.savefig("keras_gen"+os.sep+str(i)+'.jpg',dpi=1)
  plt.close()
 ##################################################################
 #重整图片到0-1
 else:
  for i in range(len(gen)):
  plt.plot(gen[i][0][0:30],gen[i][0][30:60],color='blue')
  plt.plot(gen[i][1][0:30],gen[i][1][30:60],color='red')
  plt.plot(gen[i][2][0:30],gen[i][2][30:60],color='green')
  plt.xlim(0.,1.)
  plt.ylim(0.,1.)
  plt.xticks(np.arange(0,1,0.1))
  plt.xticks(np.arange(0,1,0.1))
  plt.show()
 
if __name__ == '__main__':
 gan = GAN()
 gan.train(epochs=300000, batch_size=32, sample_interval=2000)
# gan.test(save=True)

以上这篇Keras自动下载的数据集/模型存放位置介绍就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python去除列表中重复元素的方法
Mar 20 Python
使用Python脚本实现批量网站存活检测遇到问题及解决方法
Oct 11 Python
CentOS 7下Python 2.7升级至Python3.6.1的实战教程
Jul 06 Python
简单了解OpenCV是个什么东西
Nov 10 Python
Python调用C++,通过Pybind11制作Python接口
Oct 16 Python
python3使用pandas获取股票数据的方法
Dec 22 Python
python 整数越界问题详解
Jun 27 Python
Python:Numpy 求平均向量的实例
Jun 29 Python
Python基于pygame实现单机版五子棋对战
Dec 26 Python
Python run()函数和start()函数的比较和差别介绍
May 03 Python
一篇文章搞懂python的转义字符及用法
Sep 03 Python
python自动打开浏览器下载zip并提取内容写入excel
Jan 04 Python
Python应用实现处理excel数据过程解析
Jun 19 #Python
在tensorflow以及keras安装目录查询操作(windows下)
Jun 19 #Python
Scrapy框架介绍之Puppeteer渲染的使用
Jun 19 #Python
Python内置方法和属性应用:反射和单例(推荐)
Jun 19 #Python
Python应用实现双指数函数及拟合代码实例
Jun 19 #Python
PyQT5 实现快捷键复制表格数据的方法示例
Jun 19 #Python
如何在keras中添加自己的优化器(如adam等)
Jun 19 #Python
You might like
PHP UTF8中文字符截断函数代码
2012/09/11 PHP
PHP基于GD库的缩略图生成代码(支持jpg,gif,png格式)
2014/06/19 PHP
PHP基于cookie与session统计网站访问量并输出显示的方法
2016/01/15 PHP
Laravel框架生命周期与原理分析
2018/06/12 PHP
js函数定时器实现定时读取系统实时连接数
2014/04/30 Javascript
jquery分隔Url的param方法(推荐)
2016/05/25 Javascript
H5移动端适配 Flexible方案
2016/10/24 Javascript
详解js树形控件—zTree使用总结
2016/12/28 Javascript
js实现年月日表单三级联动
2020/04/17 Javascript
JS解决IOS中拍照图片预览旋转90度BUG的问题
2017/09/13 Javascript
使用jQuery实现两个div中按钮互换位置的实例代码
2017/09/21 jQuery
微信小程序开发之好友列表字母列表跳转对应位置
2017/09/26 Javascript
jquery中有哪些api jQuery主要API
2017/11/20 jQuery
JS的函数调用栈stack size的计算方法
2018/06/24 Javascript
Vue $emit $refs子父组件间方法的调用实例
2018/09/12 Javascript
如何对react hooks进行单元测试的方法
2019/08/14 Javascript
Vue+Element-U实现分页显示效果
2020/11/15 Javascript
[02:49]2018DOTA2亚洲邀请赛主赛事决赛日战况回顾 Mineski鏖战5局夺得辉耀
2018/04/10 DOTA
Python显示进度条的方法
2014/09/20 Python
部署Python的框架下的web app的详细教程
2015/04/30 Python
在Python的Tornado框架中实现简单的在线代理的教程
2015/05/02 Python
python判断设备是否联网的方法
2018/06/29 Python
Python中shapefile转换geojson的示例
2019/01/03 Python
Python3 单行多行万能正则匹配方法
2019/01/07 Python
全面介绍python中很常用的单元测试框架unitest
2020/12/14 Python
美国领先的水果篮送货公司和新鲜水果供应商:The Fruit Company
2018/02/13 全球购物
财政局长自荐信范文
2013/12/22 职场文书
厂长岗位职责
2014/02/19 职场文书
原料仓管员岗位职责
2014/04/12 职场文书
暑期教师培训方案
2014/06/07 职场文书
党员三严三实心得体会
2014/10/13 职场文书
群众路线四风对照检查材料
2014/11/04 职场文书
2014年人力资源工作总结
2014/11/19 职场文书
碧霞祠导游词
2015/02/09 职场文书
ROS系统将python包编译为可执行文件的简单步骤
2021/07/25 Python
python全面解析接口返回数据
2022/02/12 Python