OpenCV python sklearn随机超参数搜索的实现


Posted in Python onJanuary 17, 2020

本文介绍了OpenCV python sklearn随机超参数搜索的实现,分享给大家,具体如下:

"""
房价预测数据集 使用sklearn执行超参数搜索
"""
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import tensorflow as tf
from tensorflow_core.python.keras.api._v2 import keras # 不能使用 python
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from scipy.stats import reciprocal

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')

# 0.打印导入模块的版本
print(tf.__version__)
print(sys.version_info)
for module in mpl, np, sklearn, pd, tf, keras:
  print("%s version:%s" % (module.__name__, module.__version__))


# 显示学习曲线
def plot_learning_curves(his):
  pd.DataFrame(his.history).plot(figsize=(8, 5))
  plt.grid(True)
  plt.gca().set_ylim(0, 1)
  plt.show()


# 1.加载数据集 california 房价
housing = fetch_california_housing()

print(housing.DESCR)
print(housing.data.shape)
print(housing.target.shape)

# 2.拆分数据集 训练集 验证集 测试集
x_train_all, x_test, y_train_all, y_test = train_test_split(
  housing.data, housing.target, random_state=7)
x_train, x_valid, y_train, y_valid = train_test_split(
  x_train_all, y_train_all, random_state=11)

print(x_train.shape, y_train.shape)
print(x_valid.shape, y_valid.shape)
print(x_test.shape, y_test.shape)

# 3.数据集归一化
scaler = StandardScaler()
x_train_scaled = scaler.fit_transform(x_train)
x_valid_scaled = scaler.fit_transform(x_valid)
x_test_scaled = scaler.fit_transform(x_test)


# 创建keras模型
def build_model(hidden_layers=1, # 中间层的参数
        layer_size=30,
        learning_rate=3e-3):
  # 创建网络层
  model = keras.models.Sequential()
  model.add(keras.layers.Dense(layer_size, activation="relu",
                 input_shape=x_train.shape[1:]))
 # 隐藏层设置
  for _ in range(hidden_layers - 1):
    model.add(keras.layers.Dense(layer_size,
                   activation="relu"))
  model.add(keras.layers.Dense(1))

  # 优化器学习率
  optimizer = keras.optimizers.SGD(lr=learning_rate)
  model.compile(loss="mse", optimizer=optimizer)

  return model


def main():
  # RandomizedSearchCV

  # 1.转化为sklearn的model
  sk_learn_model = keras.wrappers.scikit_learn.KerasRegressor(build_model)

  callbacks = [keras.callbacks.EarlyStopping(patience=5, min_delta=1e-2)]

  history = sk_learn_model.fit(x_train_scaled, y_train, epochs=100,
                 validation_data=(x_valid_scaled, y_valid),
                 callbacks=callbacks)
  # 2.定义超参数集合
  # f(x) = 1/(x*log(b/a)) a <= x <= b
  param_distribution = {
    "hidden_layers": [1, 2, 3, 4],
    "layer_size": np.arange(1, 100),
    "learning_rate": reciprocal(1e-4, 1e-2),
  }

  # 3.执行超搜索参数
  # cross_validation:训练集分成n份, n-1训练, 最后一份验证.
  random_search_cv = RandomizedSearchCV(sk_learn_model, param_distribution,
                     n_iter=10,
                     cv=3,
                     n_jobs=1)
  random_search_cv.fit(x_train_scaled, y_train, epochs=100,
             validation_data=(x_valid_scaled, y_valid),
             callbacks=callbacks)
  # 4.显示超参数
  print(random_search_cv.best_params_)
  print(random_search_cv.best_score_)
  print(random_search_cv.best_estimator_)

  model = random_search_cv.best_estimator_.model
  print(model.evaluate(x_test_scaled, y_test))

  # 5.打印模型训练过程
  plot_learning_curves(history)


if __name__ == '__main__':
  main()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中的内置函数getattr()介绍及示例
Jul 20 Python
Python学习笔记之解析json的方法分析
Apr 21 Python
Python列表和元组的定义与使用操作示例
Jul 26 Python
pandas中的DataFrame按指定顺序输出所有列的方法
Apr 10 Python
Python使用Dijkstra算法实现求解图中最短路径距离问题详解
May 16 Python
Python生成器的使用方法和示例代码
Mar 04 Python
详解在Python中以绝对路径或者相对路径导入文件的方法
Aug 30 Python
opencv python 图片读取与显示图片窗口未响应问题的解决
Apr 24 Python
Django 用户登陆访问限制实例 @login_required
May 13 Python
python新手学习可变和不可变对象
Jun 11 Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 Python
详解Python中@staticmethod和@classmethod区别及使用示例代码
Dec 14 Python
python numpy 矩阵堆叠实例
Jan 17 #Python
Python利用Scrapy框架爬取豆瓣电影示例
Jan 17 #Python
Python下利用BeautifulSoup解析HTML的实现
Jan 17 #Python
pytorch forward两个参数实例
Jan 17 #Python
Python实现CNN的多通道输入实例
Jan 17 #Python
Python面向对象编程基础实例分析
Jan 17 #Python
通过python实现windows桌面截图代码实例
Jan 17 #Python
You might like
php中var_export与var_dump的区别分析
2010/08/21 PHP
Php中用PDO查询Mysql来避免SQL注入风险的方法
2013/04/25 PHP
PHP实现显示照片exif信息的方法
2014/07/11 PHP
常用PHP框架功能对照表
2014/10/23 PHP
PHP实现的pdo连接数据库并插入数据功能简单示例
2019/03/30 PHP
javascript实现上传图片前的预览(TX的面试题)
2007/08/20 Javascript
jqTransform form表单美化插件使用方法
2012/07/05 Javascript
原生Js实现元素渐隐/渐现(原理为修改元素的css透明度)
2013/06/24 Javascript
判断window.onload是否多次使用的方法
2014/09/21 Javascript
nodejs创建web服务器之hello world程序
2015/08/20 NodeJs
iframe中子父类窗口调用JS的方法及注意事项
2015/08/25 Javascript
基于JS代码实现实时显示系统时间
2016/06/16 Javascript
Bootstrap基本插件学习笔记之模态对话框(16)
2016/12/08 Javascript
JS仿京东移动端手指拨动切换轮播图效果
2020/04/10 Javascript
微信小程序 动态的设置图片的高度和宽度详解及实例代码
2017/02/24 Javascript
vue的mixins属性详解
2018/03/14 Javascript
vue2.0+vuex+localStorage代办事项应用实现详解
2018/05/31 Javascript
JS实现区分中英文并统计字符个数的方法示例
2018/06/09 Javascript
React+Webpack快速上手指南(小结)
2018/08/15 Javascript
JS script脚本中async和defer区别详解
2020/06/24 Javascript
JavaScript语法约定和程序调试原理解析
2020/11/03 Javascript
用python读写excel的方法
2014/11/18 Python
python监控linux内存并写入mongodb(推荐)
2017/09/11 Python
Python实现调度算法代码详解
2017/12/01 Python
pycharm设置鼠标悬停查看方法设置
2019/07/29 Python
Python队列RabbitMQ 使用方法实例记录
2019/08/05 Python
Python GUI学习之登录系统界面篇
2019/08/21 Python
Python实现滑动平均(Moving Average)的例子
2019/08/24 Python
Python手绘可视化工具cutecharts使用实例
2019/12/05 Python
CSS3网格的三个新特性详解
2014/04/04 HTML / CSS
介绍java中初始化块的使用
2012/09/11 面试题
制冷与电控专业应届生求职信
2013/11/11 职场文书
2015年世界环境日演讲稿
2015/03/18 职场文书
解决Golang time.Parse和time.Format的时区问题
2021/04/29 Golang
Node-Red实现MySQL数据库连接的方法
2021/08/07 MySQL
Apache Pulsar集群搭建部署详细过程
2022/02/12 Servers