softmax及python实现过程解析


Posted in Python onSeptember 30, 2019

相对于自适应神经网络、感知器,softmax巧妙低使用简单的方法来实现多分类问题。

  • 功能上,完成从N维向量到M维向量的映射
  • 输出的结果范围是[0, 1],对于一个sample的结果所有输出总和等于1
  • 输出结果,可以隐含地表达该类别的概率

softmax的损失函数是采用了多分类问题中常见的交叉熵,注意经常有2个表达的形式

  • 经典的交叉熵形式:L=-sum(y_right * log(y_pred)), 具体
  • 简单版本是: L = -Log(y_pred),具体

这两个版本在求导过程有点不同,但是结果都是一样的,同时损失表达的意思也是相同的,因为在第一种表达形式中,当y不是

正确分类时,y_right等于0,当y是正确分类时,y_right等于1。

下面基于mnist数据做了一个多分类的实验,整体能达到85%的精度。

'''
softmax classifier for mnist 

created on 2019.9.28
author: vince
'''
import math
import logging
import numpy 
import random
import matplotlib.pyplot as plt
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
from sklearn.metrics import accuracy_score

def loss_max_right_class_prob(predictions, y):
	return -predictions[numpy.argmax(y)];

def loss_cross_entropy(predictions, y):
	return -numpy.dot(y, numpy.log(predictions));
	
'''
Softmax classifier
linear classifier 
'''
class Softmax:

	def __init__(self, iter_num = 100000, batch_size = 1):
		self.__iter_num = iter_num;
		self.__batch_size = batch_size;
	
	def train(self, train_X, train_Y):
		X = numpy.c_[train_X, numpy.ones(train_X.shape[0])];
		Y = numpy.copy(train_Y);

		self.L = [];

		#initialize parameters
		self.__weight = numpy.random.rand(X.shape[1], 10) * 2 - 1.0;
		self.__step_len = 1e-3; 

		logging.info("weight:%s" % (self.__weight));

		for iter_index in range(self.__iter_num):
			if iter_index % 1000 == 0:
				logging.info("-----iter:%s-----" % (iter_index));
			if iter_index % 100 == 0:
				l = 0;
				for i in range(0, len(X), 100):
					predictions = self.forward_pass(X[i]);
					#l += loss_max_right_class_prob(predictions, Y[i]); 
					l += loss_cross_entropy(predictions, Y[i]); 
				l /= len(X);
				self.L.append(l);

			sample_index = random.randint(0, len(X) - 1);
			logging.debug("-----select sample %s-----" % (sample_index));

			z = numpy.dot(X[sample_index], self.__weight);
			z = z - numpy.max(z);
			predictions = numpy.exp(z) / numpy.sum(numpy.exp(z));
			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot((predictions - Y[sample_index]).reshape(1, -1));
#			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot(predictions.reshape(1, -1)); 
#			dw[range(X.shape[1]), numpy.argmax(Y[sample_index])] -= X[sample_index] * self.__step_len;

			self.__weight -= dw;

			logging.debug("weight:%s" % (self.__weight));
			logging.debug("loss:%s" % (l));
		logging.info("weight:%s" % (self.__weight));
		logging.info("L:%s" % (self.L));
	
	def forward_pass(self, x):
		net = numpy.dot(x, self.__weight);
		net = net - numpy.max(net);
		net = numpy.exp(net) / numpy.sum(numpy.exp(net)); 
		return net;

	def predict(self, x):
		x = numpy.append(x, 1.0);
		return self.forward_pass(x);


def main():
	logging.basicConfig(level = logging.INFO,
			format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
			datefmt = '%a, %d %b %Y %H:%M:%S');
			
	logging.info("trainning begin.");

	mnist = read_data_sets('../data/MNIST',one_hot=True)  # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签

	#load data
	train_X = mnist.train.images        #训练集样本
	validation_X = mnist.validation.images   #验证集样本
	test_X = mnist.test.images         #测试集样本
	#labels
	train_Y = mnist.train.labels        #训练集标签
	validation_Y = mnist.validation.labels   #验证集标签
	test_Y = mnist.test.labels         #测试集标签

	classifier = Softmax();
	classifier.train(train_X, train_Y);

	logging.info("trainning end. predict begin.");

	test_predict = numpy.array([]);
	test_right = numpy.array([]);
	for i in range(len(test_X)):
		predict_label = numpy.argmax(classifier.predict(test_X[i]));
		test_predict = numpy.append(test_predict, predict_label);
		right_label = numpy.argmax(test_Y[i]);
		test_right = numpy.append(test_right, right_label);

	logging.info("right:%s, predict:%s" % (test_right, test_predict));
	score = accuracy_score(test_right, test_predict);
	logging.info("The accruacy score is: %s "% (str(score)));


	plt.plot(classifier.L)
	plt.show();

if __name__ == "__main__":
	main();

损失函数收敛情况

softmax及python实现过程解析

Sun, 29 Sep 2019 18:08:08 softmax.py[line:104] INFO trainning end. predict begin.
Sun, 29 Sep 2019 18:08:08 softmax.py[line:114] INFO right:[7. 2. 1. ... 4. 5. 6.], predict:[7. 2. 1. ... 4. 8. 6.]
Sun, 29 Sep 2019 18:08:08 softmax.py[line:116] INFO The accruacy score is: 0.8486

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python输出当前目录下index.html文件路径的方法
Apr 28 Python
Python获取CPU、内存使用率以及网络使用状态代码
Feb 08 Python
Python 输入一个数字判断成绩分数等级的方法
Nov 15 Python
Python中三元表达式的几种写法介绍
Mar 04 Python
Python 多个图同时在不同窗口显示的实现方法
Jul 07 Python
Python基于BeautifulSoup和requests实现的爬虫功能示例
Aug 02 Python
Python文件路径名的操作方法
Oct 30 Python
python paramiko远程服务器终端操作过程解析
Dec 14 Python
使用tensorflow实现矩阵分解方式
Feb 07 Python
PyCharm GUI界面开发和exe文件生成的实现
Mar 04 Python
基于pandas向csv添加新的行和列
May 25 Python
pytorch中的model=model.to(device)使用说明
May 24 Python
python根据时间获取周数代码实例
Sep 30 #Python
Win10 安装PyCharm2019.1.1(图文教程)
Sep 29 #Python
PyCharm2019安装教程及其使用(图文教程)
Sep 29 #Python
Python 文件操作之读取文件(read),文件指针与写入文件(write),文件打开方式示例
Sep 29 #Python
python3.7 利用函数os pandas利用excel对文件名进行归类
Sep 29 #Python
Python 多线程,threading模块,创建子线程的两种方式示例
Sep 29 #Python
Python 继承,重写,super()调用父类方法操作示例
Sep 29 #Python
You might like
国内php原创论坛
2006/10/09 PHP
关于BIG5-HKSCS的解决方法
2007/03/20 PHP
php从数组中随机抽取一些元素的代码
2012/11/05 PHP
php 获取SWF动画截图示例代码
2014/02/10 PHP
php中mt_rand()随机数函数用法
2014/11/24 PHP
PHP实现的简单mock json脚本分享
2015/02/10 PHP
php文件读取方法实例分析
2015/06/20 PHP
PHP中的Trait 特性及作用
2016/04/03 PHP
PHP数据的提交与过滤基本操作实例详解
2016/11/11 PHP
tp5.1 框架路由操作-URL生成实例分析
2020/05/26 PHP
在IE,Firefox,Safari,Chrome,Opera浏览器上调试javascript
2008/12/02 Javascript
Javascript:为input设置readOnly属性(示例讲解)
2013/12/25 Javascript
随鼠标移动的时钟非常漂亮遗憾的是只支持IE
2014/08/12 Javascript
jquery移动点击的项目到列表最顶端的方法
2015/06/24 Javascript
功能强大的jquery.validate表单验证插件
2016/11/07 Javascript
基于Javascript实现的不重复ID的生成器
2016/12/25 Javascript
详解基于webpack和vue.js搭建开发环境
2017/04/05 Javascript
JavaScript内存泄漏的处理方式
2017/11/20 Javascript
vuex 使用文档小结篇
2018/01/11 Javascript
js+css实现红包雨效果
2018/07/12 Javascript
element-ui 上传图片后清空图片显示的实例
2018/09/04 Javascript
js实现列表向上无限滚动
2020/01/13 Javascript
用js编写留言板
2020/03/17 Javascript
[01:04:08]完美世界DOTA2联赛PWL S3 INK ICE vs GXR 第一场 12.16
2020/12/18 DOTA
浅谈python之自动化运维(Paramiko)
2020/01/31 Python
HTML5新增加标签和功能概述
2016/09/05 HTML / CSS
Expedia泰国:预订机票、酒店和旅游包(航班+酒店)
2016/09/27 全球购物
德国运动鞋网上商店:Afew Store
2018/01/05 全球购物
Moss Bros官网:英国排名第一的西装店
2020/02/26 全球购物
什么是serialVersionUID
2016/03/04 面试题
请介绍一下WSDL的文档结构
2013/03/17 面试题
党员自我评价分享
2013/12/13 职场文书
美术国培研修感言
2014/02/12 职场文书
授权委托书
2014/09/17 职场文书
防汛工作情况汇报
2014/10/28 职场文书
2016教师年度考核评语大全
2015/12/01 职场文书