softmax及python实现过程解析


Posted in Python onSeptember 30, 2019

相对于自适应神经网络、感知器,softmax巧妙低使用简单的方法来实现多分类问题。

  • 功能上,完成从N维向量到M维向量的映射
  • 输出的结果范围是[0, 1],对于一个sample的结果所有输出总和等于1
  • 输出结果,可以隐含地表达该类别的概率

softmax的损失函数是采用了多分类问题中常见的交叉熵,注意经常有2个表达的形式

  • 经典的交叉熵形式:L=-sum(y_right * log(y_pred)), 具体
  • 简单版本是: L = -Log(y_pred),具体

这两个版本在求导过程有点不同,但是结果都是一样的,同时损失表达的意思也是相同的,因为在第一种表达形式中,当y不是

正确分类时,y_right等于0,当y是正确分类时,y_right等于1。

下面基于mnist数据做了一个多分类的实验,整体能达到85%的精度。

'''
softmax classifier for mnist 

created on 2019.9.28
author: vince
'''
import math
import logging
import numpy 
import random
import matplotlib.pyplot as plt
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
from sklearn.metrics import accuracy_score

def loss_max_right_class_prob(predictions, y):
	return -predictions[numpy.argmax(y)];

def loss_cross_entropy(predictions, y):
	return -numpy.dot(y, numpy.log(predictions));
	
'''
Softmax classifier
linear classifier 
'''
class Softmax:

	def __init__(self, iter_num = 100000, batch_size = 1):
		self.__iter_num = iter_num;
		self.__batch_size = batch_size;
	
	def train(self, train_X, train_Y):
		X = numpy.c_[train_X, numpy.ones(train_X.shape[0])];
		Y = numpy.copy(train_Y);

		self.L = [];

		#initialize parameters
		self.__weight = numpy.random.rand(X.shape[1], 10) * 2 - 1.0;
		self.__step_len = 1e-3; 

		logging.info("weight:%s" % (self.__weight));

		for iter_index in range(self.__iter_num):
			if iter_index % 1000 == 0:
				logging.info("-----iter:%s-----" % (iter_index));
			if iter_index % 100 == 0:
				l = 0;
				for i in range(0, len(X), 100):
					predictions = self.forward_pass(X[i]);
					#l += loss_max_right_class_prob(predictions, Y[i]); 
					l += loss_cross_entropy(predictions, Y[i]); 
				l /= len(X);
				self.L.append(l);

			sample_index = random.randint(0, len(X) - 1);
			logging.debug("-----select sample %s-----" % (sample_index));

			z = numpy.dot(X[sample_index], self.__weight);
			z = z - numpy.max(z);
			predictions = numpy.exp(z) / numpy.sum(numpy.exp(z));
			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot((predictions - Y[sample_index]).reshape(1, -1));
#			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot(predictions.reshape(1, -1)); 
#			dw[range(X.shape[1]), numpy.argmax(Y[sample_index])] -= X[sample_index] * self.__step_len;

			self.__weight -= dw;

			logging.debug("weight:%s" % (self.__weight));
			logging.debug("loss:%s" % (l));
		logging.info("weight:%s" % (self.__weight));
		logging.info("L:%s" % (self.L));
	
	def forward_pass(self, x):
		net = numpy.dot(x, self.__weight);
		net = net - numpy.max(net);
		net = numpy.exp(net) / numpy.sum(numpy.exp(net)); 
		return net;

	def predict(self, x):
		x = numpy.append(x, 1.0);
		return self.forward_pass(x);


def main():
	logging.basicConfig(level = logging.INFO,
			format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
			datefmt = '%a, %d %b %Y %H:%M:%S');
			
	logging.info("trainning begin.");

	mnist = read_data_sets('../data/MNIST',one_hot=True)  # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签

	#load data
	train_X = mnist.train.images        #训练集样本
	validation_X = mnist.validation.images   #验证集样本
	test_X = mnist.test.images         #测试集样本
	#labels
	train_Y = mnist.train.labels        #训练集标签
	validation_Y = mnist.validation.labels   #验证集标签
	test_Y = mnist.test.labels         #测试集标签

	classifier = Softmax();
	classifier.train(train_X, train_Y);

	logging.info("trainning end. predict begin.");

	test_predict = numpy.array([]);
	test_right = numpy.array([]);
	for i in range(len(test_X)):
		predict_label = numpy.argmax(classifier.predict(test_X[i]));
		test_predict = numpy.append(test_predict, predict_label);
		right_label = numpy.argmax(test_Y[i]);
		test_right = numpy.append(test_right, right_label);

	logging.info("right:%s, predict:%s" % (test_right, test_predict));
	score = accuracy_score(test_right, test_predict);
	logging.info("The accruacy score is: %s "% (str(score)));


	plt.plot(classifier.L)
	plt.show();

if __name__ == "__main__":
	main();

损失函数收敛情况

softmax及python实现过程解析

Sun, 29 Sep 2019 18:08:08 softmax.py[line:104] INFO trainning end. predict begin.
Sun, 29 Sep 2019 18:08:08 softmax.py[line:114] INFO right:[7. 2. 1. ... 4. 5. 6.], predict:[7. 2. 1. ... 4. 8. 6.]
Sun, 29 Sep 2019 18:08:08 softmax.py[line:116] INFO The accruacy score is: 0.8486

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python来编写HTTP服务器的超级指南
Feb 18 Python
在阿里云服务器上配置CentOS+Nginx+Python+Flask环境
Jun 18 Python
利用Python为iOS10生成图标和截屏
Sep 24 Python
python字符串过滤性能比较5种方法
Jun 22 Python
Python数据结构与算法之图的基本实现及迭代器实例详解
Dec 12 Python
Python mutiprocessing多线程池pool操作示例
Jan 30 Python
Django实现学生管理系统
Feb 26 Python
解决yum对python依赖版本问题
Jul 05 Python
Django框架之登录后自定义跳转页面的实现方法
Jul 18 Python
pytorch下大型数据集(大型图片)的导入方式
Jan 08 Python
python使用pyecharts库画地图数据可视化的实现
Mar 25 Python
python实现一个简单的贪吃蛇游戏附代码
Jun 28 Python
python根据时间获取周数代码实例
Sep 30 #Python
Win10 安装PyCharm2019.1.1(图文教程)
Sep 29 #Python
PyCharm2019安装教程及其使用(图文教程)
Sep 29 #Python
Python 文件操作之读取文件(read),文件指针与写入文件(write),文件打开方式示例
Sep 29 #Python
python3.7 利用函数os pandas利用excel对文件名进行归类
Sep 29 #Python
Python 多线程,threading模块,创建子线程的两种方式示例
Sep 29 #Python
Python 继承,重写,super()调用父类方法操作示例
Sep 29 #Python
You might like
php getimagesize 上传图片的长度和宽度检测代码
2010/05/15 PHP
php命令行使用方法和命令行参数说明
2014/04/08 PHP
thinkphp视图模型查询提示ERR: 1146:Table 'db.pr_order_view' doesn't exist的解决方法
2014/10/30 PHP
PHP的cURL库简介及使用示例
2015/02/06 PHP
php正则表达式获取内容所有链接
2015/07/24 PHP
php求数组全排列,元素所有组合的方法
2016/05/05 PHP
document.getElementById为空或不是对象的解决方法
2010/01/24 Javascript
javascript 图片上一张下一张链接效果代码
2010/03/12 Javascript
js类型检查实现代码
2010/10/29 Javascript
Extjs中TabPane如何嵌套在其他网页中实现思路及代码
2013/01/27 Javascript
JS 两日期相减,获得天数的小例子(兼容IE,FF)
2013/07/01 Javascript
中止javascript执行的方法
2014/02/14 Javascript
JavaScript数组和循环详解
2015/04/27 Javascript
详解JavaScript中的every()方法
2015/06/08 Javascript
jQuery实现鼠标悬停背景翻转的黑色导航菜单代码
2015/09/14 Javascript
jQuery自制提示框tooltip改进版
2016/08/01 Javascript
jQuery插件EasyUI设置datagrid的checkbox为禁用状态的方法
2016/08/05 Javascript
jQuery插件FusionCharts实现的2D饼状图效果【附demo源码下载】
2017/03/03 Javascript
前端自动化开发之Node.js的环境搭建教程
2017/04/01 Javascript
简单谈谈JS中的正则表达式
2017/09/11 Javascript
React实践之Tree组件的使用方法
2017/09/30 Javascript
vue+webpack实现异步组件加载的方法
2018/02/03 Javascript
使用FormData实现上传多个文件
2018/12/04 Javascript
python实现八大排序算法(2)
2017/09/14 Python
python3利用Dlib19.7实现人脸68个特征点标定
2018/02/26 Python
python 实现倒排索引的方法
2018/12/25 Python
python 将html转换为pdf的几种方法
2020/12/29 Python
python 统计list中各个元素出现的次数的几种方法
2021/02/20 Python
美国马匹用品和骑马配件购物网站:Horse.com
2018/01/08 全球购物
意大利在线药房:Farmacia Loreto Gallo
2019/08/09 全球购物
销售经理工作职责范文
2013/12/03 职场文书
学校组织向国旗敬礼活动方案(中小学适用)
2014/09/27 职场文书
年会主持人开场白台词
2015/05/29 职场文书
CSS3 实现NES游戏机的示例代码
2021/04/21 HTML / CSS
Nginx隐藏式跳转(浏览器URL跳转后保持不变)
2022/04/07 Servers
Windows server 2012 配置Telnet以及用法详解
2022/04/28 Servers