softmax及python实现过程解析


Posted in Python onSeptember 30, 2019

相对于自适应神经网络、感知器,softmax巧妙低使用简单的方法来实现多分类问题。

  • 功能上,完成从N维向量到M维向量的映射
  • 输出的结果范围是[0, 1],对于一个sample的结果所有输出总和等于1
  • 输出结果,可以隐含地表达该类别的概率

softmax的损失函数是采用了多分类问题中常见的交叉熵,注意经常有2个表达的形式

  • 经典的交叉熵形式:L=-sum(y_right * log(y_pred)), 具体
  • 简单版本是: L = -Log(y_pred),具体

这两个版本在求导过程有点不同,但是结果都是一样的,同时损失表达的意思也是相同的,因为在第一种表达形式中,当y不是

正确分类时,y_right等于0,当y是正确分类时,y_right等于1。

下面基于mnist数据做了一个多分类的实验,整体能达到85%的精度。

'''
softmax classifier for mnist 

created on 2019.9.28
author: vince
'''
import math
import logging
import numpy 
import random
import matplotlib.pyplot as plt
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
from sklearn.metrics import accuracy_score

def loss_max_right_class_prob(predictions, y):
	return -predictions[numpy.argmax(y)];

def loss_cross_entropy(predictions, y):
	return -numpy.dot(y, numpy.log(predictions));
	
'''
Softmax classifier
linear classifier 
'''
class Softmax:

	def __init__(self, iter_num = 100000, batch_size = 1):
		self.__iter_num = iter_num;
		self.__batch_size = batch_size;
	
	def train(self, train_X, train_Y):
		X = numpy.c_[train_X, numpy.ones(train_X.shape[0])];
		Y = numpy.copy(train_Y);

		self.L = [];

		#initialize parameters
		self.__weight = numpy.random.rand(X.shape[1], 10) * 2 - 1.0;
		self.__step_len = 1e-3; 

		logging.info("weight:%s" % (self.__weight));

		for iter_index in range(self.__iter_num):
			if iter_index % 1000 == 0:
				logging.info("-----iter:%s-----" % (iter_index));
			if iter_index % 100 == 0:
				l = 0;
				for i in range(0, len(X), 100):
					predictions = self.forward_pass(X[i]);
					#l += loss_max_right_class_prob(predictions, Y[i]); 
					l += loss_cross_entropy(predictions, Y[i]); 
				l /= len(X);
				self.L.append(l);

			sample_index = random.randint(0, len(X) - 1);
			logging.debug("-----select sample %s-----" % (sample_index));

			z = numpy.dot(X[sample_index], self.__weight);
			z = z - numpy.max(z);
			predictions = numpy.exp(z) / numpy.sum(numpy.exp(z));
			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot((predictions - Y[sample_index]).reshape(1, -1));
#			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot(predictions.reshape(1, -1)); 
#			dw[range(X.shape[1]), numpy.argmax(Y[sample_index])] -= X[sample_index] * self.__step_len;

			self.__weight -= dw;

			logging.debug("weight:%s" % (self.__weight));
			logging.debug("loss:%s" % (l));
		logging.info("weight:%s" % (self.__weight));
		logging.info("L:%s" % (self.L));
	
	def forward_pass(self, x):
		net = numpy.dot(x, self.__weight);
		net = net - numpy.max(net);
		net = numpy.exp(net) / numpy.sum(numpy.exp(net)); 
		return net;

	def predict(self, x):
		x = numpy.append(x, 1.0);
		return self.forward_pass(x);


def main():
	logging.basicConfig(level = logging.INFO,
			format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
			datefmt = '%a, %d %b %Y %H:%M:%S');
			
	logging.info("trainning begin.");

	mnist = read_data_sets('../data/MNIST',one_hot=True)  # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签

	#load data
	train_X = mnist.train.images        #训练集样本
	validation_X = mnist.validation.images   #验证集样本
	test_X = mnist.test.images         #测试集样本
	#labels
	train_Y = mnist.train.labels        #训练集标签
	validation_Y = mnist.validation.labels   #验证集标签
	test_Y = mnist.test.labels         #测试集标签

	classifier = Softmax();
	classifier.train(train_X, train_Y);

	logging.info("trainning end. predict begin.");

	test_predict = numpy.array([]);
	test_right = numpy.array([]);
	for i in range(len(test_X)):
		predict_label = numpy.argmax(classifier.predict(test_X[i]));
		test_predict = numpy.append(test_predict, predict_label);
		right_label = numpy.argmax(test_Y[i]);
		test_right = numpy.append(test_right, right_label);

	logging.info("right:%s, predict:%s" % (test_right, test_predict));
	score = accuracy_score(test_right, test_predict);
	logging.info("The accruacy score is: %s "% (str(score)));


	plt.plot(classifier.L)
	plt.show();

if __name__ == "__main__":
	main();

损失函数收敛情况

softmax及python实现过程解析

Sun, 29 Sep 2019 18:08:08 softmax.py[line:104] INFO trainning end. predict begin.
Sun, 29 Sep 2019 18:08:08 softmax.py[line:114] INFO right:[7. 2. 1. ... 4. 5. 6.], predict:[7. 2. 1. ... 4. 8. 6.]
Sun, 29 Sep 2019 18:08:08 softmax.py[line:116] INFO The accruacy score is: 0.8486

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python将字符串转换成数组的方法
Apr 29 Python
Python环境搭建之OpenCV的步骤方法
Oct 20 Python
浅谈Python处理PDF的方法
Nov 10 Python
Python3中函数参数传递方式实例详解
May 05 Python
Python将视频或者动态图gif逐帧保存为图片的方法
Sep 10 Python
pytorch之添加BN的实现
Jan 06 Python
使用tensorflow实现矩阵分解方式
Feb 07 Python
PYQT5 vscode联合操作qtdesigner的方法
Mar 24 Python
python随机模块random的22种函数(小结)
May 15 Python
python软件测试Jmeter性能测试JDBC Request(结合数据库)的使用详解
Jan 26 Python
python中os.remove()用法及注意事项
Jan 31 Python
上手简单,功能强大的Python爬虫框架——feapder
Apr 27 Python
python根据时间获取周数代码实例
Sep 30 #Python
Win10 安装PyCharm2019.1.1(图文教程)
Sep 29 #Python
PyCharm2019安装教程及其使用(图文教程)
Sep 29 #Python
Python 文件操作之读取文件(read),文件指针与写入文件(write),文件打开方式示例
Sep 29 #Python
python3.7 利用函数os pandas利用excel对文件名进行归类
Sep 29 #Python
Python 多线程,threading模块,创建子线程的两种方式示例
Sep 29 #Python
Python 继承,重写,super()调用父类方法操作示例
Sep 29 #Python
You might like
PHP中使用数组实现堆栈数据结构的代码
2012/02/05 PHP
php出现web系统多域名登录失败的解决方法
2014/09/30 PHP
PHP程序员必须清楚的问题汇总
2014/12/18 PHP
详解PHP中的序列化、反序列化操作
2017/03/21 PHP
PHP新特性详解之命名空间、性状与生成器
2017/07/18 PHP
得到文本框选中的文字,动态插入文字的js代码
2007/03/07 Javascript
JS中动态添加事件(绑定事件)的代码
2011/01/09 Javascript
50个比较实用jQuery代码段
2011/09/18 Javascript
JS替换文本域内的回车示例
2014/02/18 Javascript
Javascript 浮点运算精度问题分析与解决
2014/03/26 Javascript
Jquery实现侧边栏跟随滚动条固定(兼容IE6)
2014/04/02 Javascript
Javascript Memoizer浅析
2014/10/16 Javascript
javascript实现在下拉列表中显示多级树形菜单的方法
2015/08/12 Javascript
基于jquery实现导航菜单高亮显示(两种方法)
2015/08/23 Javascript
页面间固定参数,通过cookie传值的实现方法
2017/05/31 Javascript
使用requirejs模块化开发多页面一个入口js的使用方式
2017/06/14 Javascript
vue cli4.0项目引入typescript的方法
2020/07/17 Javascript
使用vue构建多页面应用的示例
2020/10/22 Javascript
Vue-Ant Design Vue-普通及自定义校验实例
2020/10/24 Javascript
使用rpclib进行Python网络编程时的注释问题
2015/05/06 Python
django1.8使用表单上传文件的实现方法
2016/11/04 Python
神经网络(BP)算法Python实现及应用
2018/04/16 Python
对Python模块导入时全局变量__all__的作用详解
2019/01/11 Python
在python带权重的列表中随机取值的方法
2019/01/23 Python
详解用python实现基本的学生管理系统(文件存储版)(python3)
2019/04/25 Python
itchat-python搭建微信机器人(附示例)
2019/06/11 Python
pytorch使用tensorboardX进行loss可视化实例
2020/02/24 Python
HTMl5的存储方式sessionStorage和localStorage详解
2014/03/18 HTML / CSS
大专毕业生自我评价分享
2013/11/10 职场文书
新闻专业毕业生英文求职信
2014/03/19 职场文书
养生餐厅创业计划书范文
2014/03/26 职场文书
中学生演讲稿
2014/04/26 职场文书
医院科室评语
2015/01/04 职场文书
检讨书怎么写
2015/01/23 职场文书
项目投资意向书范本
2015/05/09 职场文书
怎么用Python识别手势数字
2021/06/07 Python