softmax及python实现过程解析


Posted in Python onSeptember 30, 2019

相对于自适应神经网络、感知器,softmax巧妙低使用简单的方法来实现多分类问题。

  • 功能上,完成从N维向量到M维向量的映射
  • 输出的结果范围是[0, 1],对于一个sample的结果所有输出总和等于1
  • 输出结果,可以隐含地表达该类别的概率

softmax的损失函数是采用了多分类问题中常见的交叉熵,注意经常有2个表达的形式

  • 经典的交叉熵形式:L=-sum(y_right * log(y_pred)), 具体
  • 简单版本是: L = -Log(y_pred),具体

这两个版本在求导过程有点不同,但是结果都是一样的,同时损失表达的意思也是相同的,因为在第一种表达形式中,当y不是

正确分类时,y_right等于0,当y是正确分类时,y_right等于1。

下面基于mnist数据做了一个多分类的实验,整体能达到85%的精度。

'''
softmax classifier for mnist 

created on 2019.9.28
author: vince
'''
import math
import logging
import numpy 
import random
import matplotlib.pyplot as plt
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
from sklearn.metrics import accuracy_score

def loss_max_right_class_prob(predictions, y):
	return -predictions[numpy.argmax(y)];

def loss_cross_entropy(predictions, y):
	return -numpy.dot(y, numpy.log(predictions));
	
'''
Softmax classifier
linear classifier 
'''
class Softmax:

	def __init__(self, iter_num = 100000, batch_size = 1):
		self.__iter_num = iter_num;
		self.__batch_size = batch_size;
	
	def train(self, train_X, train_Y):
		X = numpy.c_[train_X, numpy.ones(train_X.shape[0])];
		Y = numpy.copy(train_Y);

		self.L = [];

		#initialize parameters
		self.__weight = numpy.random.rand(X.shape[1], 10) * 2 - 1.0;
		self.__step_len = 1e-3; 

		logging.info("weight:%s" % (self.__weight));

		for iter_index in range(self.__iter_num):
			if iter_index % 1000 == 0:
				logging.info("-----iter:%s-----" % (iter_index));
			if iter_index % 100 == 0:
				l = 0;
				for i in range(0, len(X), 100):
					predictions = self.forward_pass(X[i]);
					#l += loss_max_right_class_prob(predictions, Y[i]); 
					l += loss_cross_entropy(predictions, Y[i]); 
				l /= len(X);
				self.L.append(l);

			sample_index = random.randint(0, len(X) - 1);
			logging.debug("-----select sample %s-----" % (sample_index));

			z = numpy.dot(X[sample_index], self.__weight);
			z = z - numpy.max(z);
			predictions = numpy.exp(z) / numpy.sum(numpy.exp(z));
			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot((predictions - Y[sample_index]).reshape(1, -1));
#			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot(predictions.reshape(1, -1)); 
#			dw[range(X.shape[1]), numpy.argmax(Y[sample_index])] -= X[sample_index] * self.__step_len;

			self.__weight -= dw;

			logging.debug("weight:%s" % (self.__weight));
			logging.debug("loss:%s" % (l));
		logging.info("weight:%s" % (self.__weight));
		logging.info("L:%s" % (self.L));
	
	def forward_pass(self, x):
		net = numpy.dot(x, self.__weight);
		net = net - numpy.max(net);
		net = numpy.exp(net) / numpy.sum(numpy.exp(net)); 
		return net;

	def predict(self, x):
		x = numpy.append(x, 1.0);
		return self.forward_pass(x);


def main():
	logging.basicConfig(level = logging.INFO,
			format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
			datefmt = '%a, %d %b %Y %H:%M:%S');
			
	logging.info("trainning begin.");

	mnist = read_data_sets('../data/MNIST',one_hot=True)  # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签

	#load data
	train_X = mnist.train.images        #训练集样本
	validation_X = mnist.validation.images   #验证集样本
	test_X = mnist.test.images         #测试集样本
	#labels
	train_Y = mnist.train.labels        #训练集标签
	validation_Y = mnist.validation.labels   #验证集标签
	test_Y = mnist.test.labels         #测试集标签

	classifier = Softmax();
	classifier.train(train_X, train_Y);

	logging.info("trainning end. predict begin.");

	test_predict = numpy.array([]);
	test_right = numpy.array([]);
	for i in range(len(test_X)):
		predict_label = numpy.argmax(classifier.predict(test_X[i]));
		test_predict = numpy.append(test_predict, predict_label);
		right_label = numpy.argmax(test_Y[i]);
		test_right = numpy.append(test_right, right_label);

	logging.info("right:%s, predict:%s" % (test_right, test_predict));
	score = accuracy_score(test_right, test_predict);
	logging.info("The accruacy score is: %s "% (str(score)));


	plt.plot(classifier.L)
	plt.show();

if __name__ == "__main__":
	main();

损失函数收敛情况

softmax及python实现过程解析

Sun, 29 Sep 2019 18:08:08 softmax.py[line:104] INFO trainning end. predict begin.
Sun, 29 Sep 2019 18:08:08 softmax.py[line:114] INFO right:[7. 2. 1. ... 4. 5. 6.], predict:[7. 2. 1. ... 4. 8. 6.]
Sun, 29 Sep 2019 18:08:08 softmax.py[line:116] INFO The accruacy score is: 0.8486

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python多线程编程(七):使用Condition实现复杂同步
Apr 05 Python
python实现FTP服务器服务的方法
Apr 11 Python
利用python爬取散文网的文章实例教程
Jun 18 Python
Python编程给numpy矩阵添加一列方法示例
Dec 04 Python
Python操作MySQL模拟银行转账
Mar 12 Python
TensorFlow的权值更新方法
Jun 14 Python
使用pandas批量处理矢量化字符串的实例讲解
Jul 10 Python
Python判断一个list中是否包含另一个list全部元素的方法分析
Dec 24 Python
Python requests模块实例用法
Feb 11 Python
python标识符命名规范原理解析
Jan 10 Python
Python3 实现爬取网站下所有URL方式
Jan 16 Python
Pycharm自带Git实现版本管理的方法步骤
Sep 18 Python
python根据时间获取周数代码实例
Sep 30 #Python
Win10 安装PyCharm2019.1.1(图文教程)
Sep 29 #Python
PyCharm2019安装教程及其使用(图文教程)
Sep 29 #Python
Python 文件操作之读取文件(read),文件指针与写入文件(write),文件打开方式示例
Sep 29 #Python
python3.7 利用函数os pandas利用excel对文件名进行归类
Sep 29 #Python
Python 多线程,threading模块,创建子线程的两种方式示例
Sep 29 #Python
Python 继承,重写,super()调用父类方法操作示例
Sep 29 #Python
You might like
PHP+MYSQL开发工具及资源收藏
2007/01/02 PHP
PHP 命令行参数详解及应用
2011/05/18 PHP
php权重计算方法代码分享
2014/01/09 PHP
php object转数组示例
2014/01/15 PHP
PHP dirname简单使用代码实例
2020/11/13 PHP
Javascript 闭包引起的IE内存泄露分析
2012/05/23 Javascript
JavaScript中json对象和string对象之间相互转化
2012/12/26 Javascript
JavaScript中的noscript元素属性位置及作用介绍
2013/04/11 Javascript
xmlhttp缓存清除的2种解决方法
2013/12/13 Javascript
jquery插件star-rating.js实现星级评分特效
2015/04/15 Javascript
JavaScript对象反射用法实例
2015/04/17 Javascript
JS+CSS实现大气的黑色首页导航菜单效果代码
2015/09/10 Javascript
基于AngularJs + Bootstrap + AngularStrap相结合实现省市区联动代码
2016/05/30 Javascript
JS检测页面中哪个HTML标签触发点击事件的方法
2016/06/17 Javascript
简单了解常用的JavaScript 库
2020/07/16 Javascript
浅谈javascript如何获取文件后缀名
2020/08/07 Javascript
解决Mint-ui 框架Popup和Datetime Picker组件滚动穿透的问题
2020/11/04 Javascript
[01:09]模型精美,特效酷炫!TI9不朽宝藏Ⅰ鉴赏
2019/05/10 DOTA
Python群发邮件实例代码
2014/01/03 Python
Python中字符串的常见操作技巧总结
2016/07/28 Python
Tensorflow 自带可视化Tensorboard使用方法(附项目代码)
2018/02/10 Python
python 列表中[ ]中冒号‘:’的作用
2019/04/30 Python
浅谈Python 敏感词过滤的实现
2019/08/15 Python
python yield和Generator函数用法详解
2020/02/10 Python
详解用Python爬虫获取百度企业信用中企业基本信息
2020/07/02 Python
一些高难度的SQL面试题
2016/11/29 面试题
环境科学毕业生自荐信
2013/11/21 职场文书
装修致歉信
2014/01/15 职场文书
物理专业本科生自荐信
2014/01/30 职场文书
医学类个人求职信范文
2014/02/05 职场文书
护校行动方案
2014/05/31 职场文书
学校周年庆活动方案
2014/08/22 职场文书
关于读书的演讲稿800字
2014/08/27 职场文书
小学运动会报道稿
2015/07/22 职场文书
浅谈css实现背景颜色半透明的两种方法
2021/12/06 HTML / CSS
vue里使用create, mounted调用方法
2022/04/26 Vue.js