softmax及python实现过程解析


Posted in Python onSeptember 30, 2019

相对于自适应神经网络、感知器,softmax巧妙低使用简单的方法来实现多分类问题。

  • 功能上,完成从N维向量到M维向量的映射
  • 输出的结果范围是[0, 1],对于一个sample的结果所有输出总和等于1
  • 输出结果,可以隐含地表达该类别的概率

softmax的损失函数是采用了多分类问题中常见的交叉熵,注意经常有2个表达的形式

  • 经典的交叉熵形式:L=-sum(y_right * log(y_pred)), 具体
  • 简单版本是: L = -Log(y_pred),具体

这两个版本在求导过程有点不同,但是结果都是一样的,同时损失表达的意思也是相同的,因为在第一种表达形式中,当y不是

正确分类时,y_right等于0,当y是正确分类时,y_right等于1。

下面基于mnist数据做了一个多分类的实验,整体能达到85%的精度。

'''
softmax classifier for mnist 

created on 2019.9.28
author: vince
'''
import math
import logging
import numpy 
import random
import matplotlib.pyplot as plt
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
from sklearn.metrics import accuracy_score

def loss_max_right_class_prob(predictions, y):
	return -predictions[numpy.argmax(y)];

def loss_cross_entropy(predictions, y):
	return -numpy.dot(y, numpy.log(predictions));
	
'''
Softmax classifier
linear classifier 
'''
class Softmax:

	def __init__(self, iter_num = 100000, batch_size = 1):
		self.__iter_num = iter_num;
		self.__batch_size = batch_size;
	
	def train(self, train_X, train_Y):
		X = numpy.c_[train_X, numpy.ones(train_X.shape[0])];
		Y = numpy.copy(train_Y);

		self.L = [];

		#initialize parameters
		self.__weight = numpy.random.rand(X.shape[1], 10) * 2 - 1.0;
		self.__step_len = 1e-3; 

		logging.info("weight:%s" % (self.__weight));

		for iter_index in range(self.__iter_num):
			if iter_index % 1000 == 0:
				logging.info("-----iter:%s-----" % (iter_index));
			if iter_index % 100 == 0:
				l = 0;
				for i in range(0, len(X), 100):
					predictions = self.forward_pass(X[i]);
					#l += loss_max_right_class_prob(predictions, Y[i]); 
					l += loss_cross_entropy(predictions, Y[i]); 
				l /= len(X);
				self.L.append(l);

			sample_index = random.randint(0, len(X) - 1);
			logging.debug("-----select sample %s-----" % (sample_index));

			z = numpy.dot(X[sample_index], self.__weight);
			z = z - numpy.max(z);
			predictions = numpy.exp(z) / numpy.sum(numpy.exp(z));
			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot((predictions - Y[sample_index]).reshape(1, -1));
#			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot(predictions.reshape(1, -1)); 
#			dw[range(X.shape[1]), numpy.argmax(Y[sample_index])] -= X[sample_index] * self.__step_len;

			self.__weight -= dw;

			logging.debug("weight:%s" % (self.__weight));
			logging.debug("loss:%s" % (l));
		logging.info("weight:%s" % (self.__weight));
		logging.info("L:%s" % (self.L));
	
	def forward_pass(self, x):
		net = numpy.dot(x, self.__weight);
		net = net - numpy.max(net);
		net = numpy.exp(net) / numpy.sum(numpy.exp(net)); 
		return net;

	def predict(self, x):
		x = numpy.append(x, 1.0);
		return self.forward_pass(x);


def main():
	logging.basicConfig(level = logging.INFO,
			format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
			datefmt = '%a, %d %b %Y %H:%M:%S');
			
	logging.info("trainning begin.");

	mnist = read_data_sets('../data/MNIST',one_hot=True)  # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签

	#load data
	train_X = mnist.train.images        #训练集样本
	validation_X = mnist.validation.images   #验证集样本
	test_X = mnist.test.images         #测试集样本
	#labels
	train_Y = mnist.train.labels        #训练集标签
	validation_Y = mnist.validation.labels   #验证集标签
	test_Y = mnist.test.labels         #测试集标签

	classifier = Softmax();
	classifier.train(train_X, train_Y);

	logging.info("trainning end. predict begin.");

	test_predict = numpy.array([]);
	test_right = numpy.array([]);
	for i in range(len(test_X)):
		predict_label = numpy.argmax(classifier.predict(test_X[i]));
		test_predict = numpy.append(test_predict, predict_label);
		right_label = numpy.argmax(test_Y[i]);
		test_right = numpy.append(test_right, right_label);

	logging.info("right:%s, predict:%s" % (test_right, test_predict));
	score = accuracy_score(test_right, test_predict);
	logging.info("The accruacy score is: %s "% (str(score)));


	plt.plot(classifier.L)
	plt.show();

if __name__ == "__main__":
	main();

损失函数收敛情况

softmax及python实现过程解析

Sun, 29 Sep 2019 18:08:08 softmax.py[line:104] INFO trainning end. predict begin.
Sun, 29 Sep 2019 18:08:08 softmax.py[line:114] INFO right:[7. 2. 1. ... 4. 5. 6.], predict:[7. 2. 1. ... 4. 8. 6.]
Sun, 29 Sep 2019 18:08:08 softmax.py[line:116] INFO The accruacy score is: 0.8486

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
利用 Monkey 命令操作屏幕快速滑动
Dec 07 Python
python3实现ftp服务功能(客户端)
Mar 24 Python
Python简单实现Base64编码和解码的方法
Apr 29 Python
Python绘制3d螺旋曲线图实例代码
Dec 20 Python
python实现手机通讯录搜索功能
Feb 22 Python
Python中fnmatch模块的使用详情
Nov 30 Python
python找出一个列表中相同元素的多个索引实例
Jun 11 Python
python 画函数曲线示例
Dec 04 Python
python设置环境变量的作用整理
Feb 17 Python
python 6.7 编写printTable()函数表格打印(完整代码)
Mar 25 Python
对django 2.x版本中models.ForeignKey()外键说明介绍
Mar 30 Python
python如何随机生成高强度密码
Aug 19 Python
python根据时间获取周数代码实例
Sep 30 #Python
Win10 安装PyCharm2019.1.1(图文教程)
Sep 29 #Python
PyCharm2019安装教程及其使用(图文教程)
Sep 29 #Python
Python 文件操作之读取文件(read),文件指针与写入文件(write),文件打开方式示例
Sep 29 #Python
python3.7 利用函数os pandas利用excel对文件名进行归类
Sep 29 #Python
Python 多线程,threading模块,创建子线程的两种方式示例
Sep 29 #Python
Python 继承,重写,super()调用父类方法操作示例
Sep 29 #Python
You might like
php urlencode()与urldecode()函数字符编码原理详解
2011/12/06 PHP
php通过正则表达式记取数据来读取xml的方法
2015/03/09 PHP
Symfony2使用第三方库Upload制作图片上传实例详解
2016/02/04 PHP
PHP和MySql中32位和64位的整形范围是多少
2016/02/18 PHP
解决PHP程序运行时:Fatal error: Maximum execution time of 30 seconds exceeded in的错误提示
2016/11/25 PHP
javascript插入样式实现代码
2012/02/22 Javascript
js设置组合快捷键/tabindex功能的方法
2013/11/21 Javascript
jquery.validate.js插件使用经验记录
2014/07/02 Javascript
jQuery实现高亮显示网页关键词的方法
2015/08/07 Javascript
原生js三级联动的简单实现代码
2016/06/07 Javascript
js实现登录注册框手机号和验证码校验(前端部分)
2017/09/28 Javascript
解决vuex刷新状态初始化的方法实现
2019/08/15 Javascript
python dict remove数组删除(del,pop)
2013/03/24 Python
用Python展示动态规则法用以解决重叠子问题的示例
2015/04/02 Python
Django内容增加富文本功能的实例
2017/10/17 Python
numpy排序与集合运算用法示例
2017/12/15 Python
python3 实现验证码图片切割的方法
2018/12/07 Python
在python中利用最小二乘拟合二次抛物线函数的方法
2018/12/29 Python
Python2与Python3的区别实例分析
2019/04/11 Python
Python中psutil的介绍与用法
2019/05/02 Python
使用python制作一个为hex文件增加版本号的脚本实例
2019/06/12 Python
python kafka 多线程消费者&手动提交实例
2019/12/21 Python
你可能不知道的Python 技巧小结
2020/01/29 Python
如何用python处理excel表格
2020/06/09 Python
Python爬虫实例——爬取美团美食数据
2020/07/15 Python
Python 实现图片转字符画的示例(静态图片,gif皆可)
2020/11/05 Python
CSS3 仿微信聊天小气泡实例代码
2017/04/05 HTML / CSS
VLAN和VPN有什么区别?分别实现在OSI的第几层?
2014/12/23 面试题
食堂员工工作职责
2013/12/18 职场文书
助人为乐道德模范事迹材料
2014/08/16 职场文书
酒店端午节活动方案
2014/08/26 职场文书
乡镇干部个人整改措施思想汇报
2014/10/10 职场文书
党支部2014年度工作总结
2014/12/04 职场文书
2015年读书月活动总结
2015/03/26 职场文书
Mysql超详细讲解死锁问题的理解
2022/04/01 MySQL
Python 视频画质增强
2022/04/28 Python