softmax及python实现过程解析


Posted in Python onSeptember 30, 2019

相对于自适应神经网络、感知器,softmax巧妙低使用简单的方法来实现多分类问题。

  • 功能上,完成从N维向量到M维向量的映射
  • 输出的结果范围是[0, 1],对于一个sample的结果所有输出总和等于1
  • 输出结果,可以隐含地表达该类别的概率

softmax的损失函数是采用了多分类问题中常见的交叉熵,注意经常有2个表达的形式

  • 经典的交叉熵形式:L=-sum(y_right * log(y_pred)), 具体
  • 简单版本是: L = -Log(y_pred),具体

这两个版本在求导过程有点不同,但是结果都是一样的,同时损失表达的意思也是相同的,因为在第一种表达形式中,当y不是

正确分类时,y_right等于0,当y是正确分类时,y_right等于1。

下面基于mnist数据做了一个多分类的实验,整体能达到85%的精度。

'''
softmax classifier for mnist 

created on 2019.9.28
author: vince
'''
import math
import logging
import numpy 
import random
import matplotlib.pyplot as plt
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
from sklearn.metrics import accuracy_score

def loss_max_right_class_prob(predictions, y):
	return -predictions[numpy.argmax(y)];

def loss_cross_entropy(predictions, y):
	return -numpy.dot(y, numpy.log(predictions));
	
'''
Softmax classifier
linear classifier 
'''
class Softmax:

	def __init__(self, iter_num = 100000, batch_size = 1):
		self.__iter_num = iter_num;
		self.__batch_size = batch_size;
	
	def train(self, train_X, train_Y):
		X = numpy.c_[train_X, numpy.ones(train_X.shape[0])];
		Y = numpy.copy(train_Y);

		self.L = [];

		#initialize parameters
		self.__weight = numpy.random.rand(X.shape[1], 10) * 2 - 1.0;
		self.__step_len = 1e-3; 

		logging.info("weight:%s" % (self.__weight));

		for iter_index in range(self.__iter_num):
			if iter_index % 1000 == 0:
				logging.info("-----iter:%s-----" % (iter_index));
			if iter_index % 100 == 0:
				l = 0;
				for i in range(0, len(X), 100):
					predictions = self.forward_pass(X[i]);
					#l += loss_max_right_class_prob(predictions, Y[i]); 
					l += loss_cross_entropy(predictions, Y[i]); 
				l /= len(X);
				self.L.append(l);

			sample_index = random.randint(0, len(X) - 1);
			logging.debug("-----select sample %s-----" % (sample_index));

			z = numpy.dot(X[sample_index], self.__weight);
			z = z - numpy.max(z);
			predictions = numpy.exp(z) / numpy.sum(numpy.exp(z));
			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot((predictions - Y[sample_index]).reshape(1, -1));
#			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot(predictions.reshape(1, -1)); 
#			dw[range(X.shape[1]), numpy.argmax(Y[sample_index])] -= X[sample_index] * self.__step_len;

			self.__weight -= dw;

			logging.debug("weight:%s" % (self.__weight));
			logging.debug("loss:%s" % (l));
		logging.info("weight:%s" % (self.__weight));
		logging.info("L:%s" % (self.L));
	
	def forward_pass(self, x):
		net = numpy.dot(x, self.__weight);
		net = net - numpy.max(net);
		net = numpy.exp(net) / numpy.sum(numpy.exp(net)); 
		return net;

	def predict(self, x):
		x = numpy.append(x, 1.0);
		return self.forward_pass(x);


def main():
	logging.basicConfig(level = logging.INFO,
			format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
			datefmt = '%a, %d %b %Y %H:%M:%S');
			
	logging.info("trainning begin.");

	mnist = read_data_sets('../data/MNIST',one_hot=True)  # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签

	#load data
	train_X = mnist.train.images        #训练集样本
	validation_X = mnist.validation.images   #验证集样本
	test_X = mnist.test.images         #测试集样本
	#labels
	train_Y = mnist.train.labels        #训练集标签
	validation_Y = mnist.validation.labels   #验证集标签
	test_Y = mnist.test.labels         #测试集标签

	classifier = Softmax();
	classifier.train(train_X, train_Y);

	logging.info("trainning end. predict begin.");

	test_predict = numpy.array([]);
	test_right = numpy.array([]);
	for i in range(len(test_X)):
		predict_label = numpy.argmax(classifier.predict(test_X[i]));
		test_predict = numpy.append(test_predict, predict_label);
		right_label = numpy.argmax(test_Y[i]);
		test_right = numpy.append(test_right, right_label);

	logging.info("right:%s, predict:%s" % (test_right, test_predict));
	score = accuracy_score(test_right, test_predict);
	logging.info("The accruacy score is: %s "% (str(score)));


	plt.plot(classifier.L)
	plt.show();

if __name__ == "__main__":
	main();

损失函数收敛情况

softmax及python实现过程解析

Sun, 29 Sep 2019 18:08:08 softmax.py[line:104] INFO trainning end. predict begin.
Sun, 29 Sep 2019 18:08:08 softmax.py[line:114] INFO right:[7. 2. 1. ... 4. 5. 6.], predict:[7. 2. 1. ... 4. 8. 6.]
Sun, 29 Sep 2019 18:08:08 softmax.py[line:116] INFO The accruacy score is: 0.8486

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的一个找零钱的小程序代码分享
Aug 25 Python
理解Python中的With语句
Mar 18 Python
Python实现可设置持续运行时间、线程数及时间间隔的多线程异步post请求功能
Jan 11 Python
详解tensorflow训练自己的数据集实现CNN图像分类
Feb 07 Python
python读取文本绘制动态速度曲线
Jun 21 Python
pandas读取csv文件,分隔符参数sep的实例
Dec 12 Python
在python中对变量判断是否为None的三种方法总结
Jan 23 Python
在python shell中运行python文件的实现
Dec 21 Python
Python递归函数特点及原理解析
Mar 04 Python
Python MOCK SERVER moco模拟接口测试过程解析
Apr 13 Python
python爬虫中的url下载器用法详解
Nov 30 Python
用 Django 开发一个 Python Web API的方法步骤
Dec 03 Python
python根据时间获取周数代码实例
Sep 30 #Python
Win10 安装PyCharm2019.1.1(图文教程)
Sep 29 #Python
PyCharm2019安装教程及其使用(图文教程)
Sep 29 #Python
Python 文件操作之读取文件(read),文件指针与写入文件(write),文件打开方式示例
Sep 29 #Python
python3.7 利用函数os pandas利用excel对文件名进行归类
Sep 29 #Python
Python 多线程,threading模块,创建子线程的两种方式示例
Sep 29 #Python
Python 继承,重写,super()调用父类方法操作示例
Sep 29 #Python
You might like
PHP strip_tags()去除HTML、XML以及PHP的标签介绍
2014/02/18 PHP
PHP中的Memcache详解
2014/04/05 PHP
Laravel框架基础语法与知识点整理【模板变量、输出、include引入子视图等】
2019/12/03 PHP
capacityFixed 基于jquery的类似于新浪微博新消息提示的定位框
2011/05/24 Javascript
对table和ul实现js分页示例分享
2014/02/24 Javascript
浅谈javascript的分号的使用
2015/05/12 Javascript
Javascript函数的参数
2015/07/16 Javascript
jQuery滚动新闻实现代码
2016/06/26 Javascript
jquery  实现轮播图详解及实例代码
2016/10/12 Javascript
vue中for循环更改数据的实例代码(数据变化但页面数据未变)
2017/09/15 Javascript
Three.JS实现三维场景
2018/12/30 Javascript
js中Generator函数的深入讲解
2019/04/07 Javascript
no-vnc和node.js实现web远程桌面的完整步骤
2019/08/11 Javascript
jquery制作的移动端购物车效果完整示例
2020/02/24 jQuery
在Vue中使用HOC模式的实现
2020/08/23 Javascript
Vue中强制组件重新渲染的正确方法
2021/01/03 Vue.js
vue实现简易计算器功能
2021/01/20 Vue.js
[01:01:18]VP vs NIP 2019国际邀请赛小组赛 BO2 第二场 8.15
2019/08/17 DOTA
使用python获取CPU和内存信息的思路与实现(linux系统)
2014/01/03 Python
使用Python实现一个简单的项目监控
2015/03/31 Python
Python中的列表生成式与生成器学习教程
2016/03/13 Python
Windows下Python3.6安装第三方模块的方法
2018/11/22 Python
python3连接MySQL8.0的两种方式
2020/02/17 Python
如何用python免费看美剧
2020/08/11 Python
15款Python编辑器的优缺点,别再问我“选什么编辑器”啦
2020/10/19 Python
CSS3实现粒子旋转伸缩加载动画
2016/04/22 HTML / CSS
HTML5 embed 标签使用方法介绍
2013/08/13 HTML / CSS
酒店大堂副理的职责范文
2014/02/13 职场文书
幼儿园安全生产月活动总结
2014/07/05 职场文书
宾馆仓管员岗位职责
2014/07/27 职场文书
赔偿协议书范本
2014/09/12 职场文书
2014年重阳节敬老活动方案
2014/09/16 职场文书
行政申诉状范文
2015/05/20 职场文书
导游词之无锡华莱坞
2019/12/02 职场文书
Python基础知识之变量的详解
2021/04/14 Python
Jpa Specification如何实现and和or同时使用查询
2021/11/23 Java/Android