softmax及python实现过程解析


Posted in Python onSeptember 30, 2019

相对于自适应神经网络、感知器,softmax巧妙低使用简单的方法来实现多分类问题。

  • 功能上,完成从N维向量到M维向量的映射
  • 输出的结果范围是[0, 1],对于一个sample的结果所有输出总和等于1
  • 输出结果,可以隐含地表达该类别的概率

softmax的损失函数是采用了多分类问题中常见的交叉熵,注意经常有2个表达的形式

  • 经典的交叉熵形式:L=-sum(y_right * log(y_pred)), 具体
  • 简单版本是: L = -Log(y_pred),具体

这两个版本在求导过程有点不同,但是结果都是一样的,同时损失表达的意思也是相同的,因为在第一种表达形式中,当y不是

正确分类时,y_right等于0,当y是正确分类时,y_right等于1。

下面基于mnist数据做了一个多分类的实验,整体能达到85%的精度。

'''
softmax classifier for mnist 

created on 2019.9.28
author: vince
'''
import math
import logging
import numpy 
import random
import matplotlib.pyplot as plt
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
from sklearn.metrics import accuracy_score

def loss_max_right_class_prob(predictions, y):
	return -predictions[numpy.argmax(y)];

def loss_cross_entropy(predictions, y):
	return -numpy.dot(y, numpy.log(predictions));
	
'''
Softmax classifier
linear classifier 
'''
class Softmax:

	def __init__(self, iter_num = 100000, batch_size = 1):
		self.__iter_num = iter_num;
		self.__batch_size = batch_size;
	
	def train(self, train_X, train_Y):
		X = numpy.c_[train_X, numpy.ones(train_X.shape[0])];
		Y = numpy.copy(train_Y);

		self.L = [];

		#initialize parameters
		self.__weight = numpy.random.rand(X.shape[1], 10) * 2 - 1.0;
		self.__step_len = 1e-3; 

		logging.info("weight:%s" % (self.__weight));

		for iter_index in range(self.__iter_num):
			if iter_index % 1000 == 0:
				logging.info("-----iter:%s-----" % (iter_index));
			if iter_index % 100 == 0:
				l = 0;
				for i in range(0, len(X), 100):
					predictions = self.forward_pass(X[i]);
					#l += loss_max_right_class_prob(predictions, Y[i]); 
					l += loss_cross_entropy(predictions, Y[i]); 
				l /= len(X);
				self.L.append(l);

			sample_index = random.randint(0, len(X) - 1);
			logging.debug("-----select sample %s-----" % (sample_index));

			z = numpy.dot(X[sample_index], self.__weight);
			z = z - numpy.max(z);
			predictions = numpy.exp(z) / numpy.sum(numpy.exp(z));
			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot((predictions - Y[sample_index]).reshape(1, -1));
#			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot(predictions.reshape(1, -1)); 
#			dw[range(X.shape[1]), numpy.argmax(Y[sample_index])] -= X[sample_index] * self.__step_len;

			self.__weight -= dw;

			logging.debug("weight:%s" % (self.__weight));
			logging.debug("loss:%s" % (l));
		logging.info("weight:%s" % (self.__weight));
		logging.info("L:%s" % (self.L));
	
	def forward_pass(self, x):
		net = numpy.dot(x, self.__weight);
		net = net - numpy.max(net);
		net = numpy.exp(net) / numpy.sum(numpy.exp(net)); 
		return net;

	def predict(self, x):
		x = numpy.append(x, 1.0);
		return self.forward_pass(x);


def main():
	logging.basicConfig(level = logging.INFO,
			format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
			datefmt = '%a, %d %b %Y %H:%M:%S');
			
	logging.info("trainning begin.");

	mnist = read_data_sets('../data/MNIST',one_hot=True)  # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签

	#load data
	train_X = mnist.train.images        #训练集样本
	validation_X = mnist.validation.images   #验证集样本
	test_X = mnist.test.images         #测试集样本
	#labels
	train_Y = mnist.train.labels        #训练集标签
	validation_Y = mnist.validation.labels   #验证集标签
	test_Y = mnist.test.labels         #测试集标签

	classifier = Softmax();
	classifier.train(train_X, train_Y);

	logging.info("trainning end. predict begin.");

	test_predict = numpy.array([]);
	test_right = numpy.array([]);
	for i in range(len(test_X)):
		predict_label = numpy.argmax(classifier.predict(test_X[i]));
		test_predict = numpy.append(test_predict, predict_label);
		right_label = numpy.argmax(test_Y[i]);
		test_right = numpy.append(test_right, right_label);

	logging.info("right:%s, predict:%s" % (test_right, test_predict));
	score = accuracy_score(test_right, test_predict);
	logging.info("The accruacy score is: %s "% (str(score)));


	plt.plot(classifier.L)
	plt.show();

if __name__ == "__main__":
	main();

损失函数收敛情况

softmax及python实现过程解析

Sun, 29 Sep 2019 18:08:08 softmax.py[line:104] INFO trainning end. predict begin.
Sun, 29 Sep 2019 18:08:08 softmax.py[line:114] INFO right:[7. 2. 1. ... 4. 5. 6.], predict:[7. 2. 1. ... 4. 8. 6.]
Sun, 29 Sep 2019 18:08:08 softmax.py[line:116] INFO The accruacy score is: 0.8486

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python获取当前时间的方法
Jan 14 Python
windows下python模拟鼠标点击和键盘输示例
Feb 28 Python
pygame实现弹力球及其变速效果
Jul 03 Python
python中logging库的使用总结
Oct 18 Python
mac下给python3安装requests库和scrapy库的实例
Jun 13 Python
使用python实现快速搭建简易的FTP服务器
Sep 12 Python
Tesserocr库的正确安装方式
Oct 19 Python
实例详解Matlab 与 Python 的区别
Apr 26 Python
Python读写文件基础知识点
Jun 10 Python
python、Matlab求定积分的实现
Nov 20 Python
使用SimpleITK读取和保存NIfTI/DICOM文件实例
Jul 01 Python
Python OpenCV快速入门教程
Apr 17 Python
python根据时间获取周数代码实例
Sep 30 #Python
Win10 安装PyCharm2019.1.1(图文教程)
Sep 29 #Python
PyCharm2019安装教程及其使用(图文教程)
Sep 29 #Python
Python 文件操作之读取文件(read),文件指针与写入文件(write),文件打开方式示例
Sep 29 #Python
python3.7 利用函数os pandas利用excel对文件名进行归类
Sep 29 #Python
Python 多线程,threading模块,创建子线程的两种方式示例
Sep 29 #Python
Python 继承,重写,super()调用父类方法操作示例
Sep 29 #Python
You might like
PHP整数取余返回负数的相关解决方法
2014/05/15 PHP
几个优化WordPress中JavaScript加载体验的插件介绍
2015/12/17 PHP
crontab无法执行php的解决方法
2016/01/25 PHP
php+jquery+html实现点击不刷新加载更多的实例代码
2016/08/12 PHP
PHP unset函数原理及使用方法解析
2020/08/14 PHP
老鱼 浅谈javascript面向对象编程
2010/03/04 Javascript
JS 树形递归实例代码
2010/05/18 Javascript
从阶乘函数对比Javascript和C#的异同
2012/05/31 Javascript
jQuery 数据缓存模块进化史详细介绍
2012/11/19 Javascript
js 通用订单代码
2013/12/23 Javascript
判断某个字符在一个字符串中是否存在的js代码
2014/02/28 Javascript
SyntaxHighlighter 3.0.83使用笔记
2015/01/26 Javascript
jquery实现鼠标拖拽滑动效果来选择数字的方法
2015/05/04 Javascript
jQuery添加和删除指定标签的方法
2015/12/16 Javascript
DeviceOne 让你一见钟情的App快速开发平台
2016/02/17 Javascript
zTree插件下拉树使用入门教程
2016/04/11 Javascript
javascript封装addLoadEvent实现页面同时加载执行多个函数的方法
2016/07/25 Javascript
Node.js开发第三方微信公众平台
2017/06/05 Javascript
详解ionic本地相册、拍照、裁剪、上传(单图完全版)
2017/10/10 Javascript
bootstrap模态框嵌套、tabindex属性、去除阴影的示例代码
2017/10/17 Javascript
jQuery中的for循环var与let的区别
2018/04/21 jQuery
js实现mp3录音通过websocket实时传送+简易波形图效果
2020/06/12 Javascript
jQuery实现朋友圈查看图片
2020/09/11 jQuery
实现vuex原理的示例
2020/10/21 Javascript
[03:20]次级联赛厮杀超职业 现超级兵对拆世纪大战
2014/10/30 DOTA
在Python中使用PIL模块对图片进行高斯模糊处理的教程
2015/05/05 Python
浅析python协程相关概念
2018/01/20 Python
ubuntu17.4下为python和python3装上pip的方法
2018/06/12 Python
Python定义二叉树及4种遍历方法实例详解
2018/07/05 Python
使用Template格式化Python字符串的方法
2019/01/22 Python
Django框架自定义模型管理器与元选项用法分析
2019/07/22 Python
Python 复平面绘图实例
2019/11/21 Python
Chupi官网:在爱尔兰手工制作的订婚、结婚戒指和精美珠宝
2020/09/28 全球购物
技校毕业生个人学习的自我评价
2014/02/21 职场文书
房屋鉴定委托书范本
2014/09/23 职场文书
保证书格式
2015/01/16 职场文书