softmax及python实现过程解析


Posted in Python onSeptember 30, 2019

相对于自适应神经网络、感知器,softmax巧妙低使用简单的方法来实现多分类问题。

  • 功能上,完成从N维向量到M维向量的映射
  • 输出的结果范围是[0, 1],对于一个sample的结果所有输出总和等于1
  • 输出结果,可以隐含地表达该类别的概率

softmax的损失函数是采用了多分类问题中常见的交叉熵,注意经常有2个表达的形式

  • 经典的交叉熵形式:L=-sum(y_right * log(y_pred)), 具体
  • 简单版本是: L = -Log(y_pred),具体

这两个版本在求导过程有点不同,但是结果都是一样的,同时损失表达的意思也是相同的,因为在第一种表达形式中,当y不是

正确分类时,y_right等于0,当y是正确分类时,y_right等于1。

下面基于mnist数据做了一个多分类的实验,整体能达到85%的精度。

'''
softmax classifier for mnist 

created on 2019.9.28
author: vince
'''
import math
import logging
import numpy 
import random
import matplotlib.pyplot as plt
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
from sklearn.metrics import accuracy_score

def loss_max_right_class_prob(predictions, y):
	return -predictions[numpy.argmax(y)];

def loss_cross_entropy(predictions, y):
	return -numpy.dot(y, numpy.log(predictions));
	
'''
Softmax classifier
linear classifier 
'''
class Softmax:

	def __init__(self, iter_num = 100000, batch_size = 1):
		self.__iter_num = iter_num;
		self.__batch_size = batch_size;
	
	def train(self, train_X, train_Y):
		X = numpy.c_[train_X, numpy.ones(train_X.shape[0])];
		Y = numpy.copy(train_Y);

		self.L = [];

		#initialize parameters
		self.__weight = numpy.random.rand(X.shape[1], 10) * 2 - 1.0;
		self.__step_len = 1e-3; 

		logging.info("weight:%s" % (self.__weight));

		for iter_index in range(self.__iter_num):
			if iter_index % 1000 == 0:
				logging.info("-----iter:%s-----" % (iter_index));
			if iter_index % 100 == 0:
				l = 0;
				for i in range(0, len(X), 100):
					predictions = self.forward_pass(X[i]);
					#l += loss_max_right_class_prob(predictions, Y[i]); 
					l += loss_cross_entropy(predictions, Y[i]); 
				l /= len(X);
				self.L.append(l);

			sample_index = random.randint(0, len(X) - 1);
			logging.debug("-----select sample %s-----" % (sample_index));

			z = numpy.dot(X[sample_index], self.__weight);
			z = z - numpy.max(z);
			predictions = numpy.exp(z) / numpy.sum(numpy.exp(z));
			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot((predictions - Y[sample_index]).reshape(1, -1));
#			dw = self.__step_len * X[sample_index].reshape(-1, 1).dot(predictions.reshape(1, -1)); 
#			dw[range(X.shape[1]), numpy.argmax(Y[sample_index])] -= X[sample_index] * self.__step_len;

			self.__weight -= dw;

			logging.debug("weight:%s" % (self.__weight));
			logging.debug("loss:%s" % (l));
		logging.info("weight:%s" % (self.__weight));
		logging.info("L:%s" % (self.L));
	
	def forward_pass(self, x):
		net = numpy.dot(x, self.__weight);
		net = net - numpy.max(net);
		net = numpy.exp(net) / numpy.sum(numpy.exp(net)); 
		return net;

	def predict(self, x):
		x = numpy.append(x, 1.0);
		return self.forward_pass(x);


def main():
	logging.basicConfig(level = logging.INFO,
			format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
			datefmt = '%a, %d %b %Y %H:%M:%S');
			
	logging.info("trainning begin.");

	mnist = read_data_sets('../data/MNIST',one_hot=True)  # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签

	#load data
	train_X = mnist.train.images        #训练集样本
	validation_X = mnist.validation.images   #验证集样本
	test_X = mnist.test.images         #测试集样本
	#labels
	train_Y = mnist.train.labels        #训练集标签
	validation_Y = mnist.validation.labels   #验证集标签
	test_Y = mnist.test.labels         #测试集标签

	classifier = Softmax();
	classifier.train(train_X, train_Y);

	logging.info("trainning end. predict begin.");

	test_predict = numpy.array([]);
	test_right = numpy.array([]);
	for i in range(len(test_X)):
		predict_label = numpy.argmax(classifier.predict(test_X[i]));
		test_predict = numpy.append(test_predict, predict_label);
		right_label = numpy.argmax(test_Y[i]);
		test_right = numpy.append(test_right, right_label);

	logging.info("right:%s, predict:%s" % (test_right, test_predict));
	score = accuracy_score(test_right, test_predict);
	logging.info("The accruacy score is: %s "% (str(score)));


	plt.plot(classifier.L)
	plt.show();

if __name__ == "__main__":
	main();

损失函数收敛情况

softmax及python实现过程解析

Sun, 29 Sep 2019 18:08:08 softmax.py[line:104] INFO trainning end. predict begin.
Sun, 29 Sep 2019 18:08:08 softmax.py[line:114] INFO right:[7. 2. 1. ... 4. 5. 6.], predict:[7. 2. 1. ... 4. 8. 6.]
Sun, 29 Sep 2019 18:08:08 softmax.py[line:116] INFO The accruacy score is: 0.8486

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用Flask框架获取用户IP地址的方法
Mar 21 Python
Eclipse和PyDev搭建完美Python开发环境教程(Windows篇)
Nov 16 Python
Python字典操作详细介绍及字典内建方法分享
Jan 04 Python
Django Admin实现三级联动的示例代码(省市区)
Jun 22 Python
python使用matplotlib绘制热图
Nov 07 Python
利用Python模拟登录pastebin.com的实现方法
Jul 12 Python
解决python3 安装不了PIL的问题
Aug 16 Python
Python批量启动多线程代码实例
Feb 18 Python
Python小白垃圾回收机制入门
Jun 09 Python
Python StringIO及BytesIO包使用方法解析
Jun 15 Python
python实例化对象的具体方法
Jun 17 Python
python 三种方法实现对Excel表格的读写
Nov 19 Python
python根据时间获取周数代码实例
Sep 30 #Python
Win10 安装PyCharm2019.1.1(图文教程)
Sep 29 #Python
PyCharm2019安装教程及其使用(图文教程)
Sep 29 #Python
Python 文件操作之读取文件(read),文件指针与写入文件(write),文件打开方式示例
Sep 29 #Python
python3.7 利用函数os pandas利用excel对文件名进行归类
Sep 29 #Python
Python 多线程,threading模块,创建子线程的两种方式示例
Sep 29 #Python
Python 继承,重写,super()调用父类方法操作示例
Sep 29 #Python
You might like
一步一步学习PHP(8) php 数组
2010/03/05 PHP
php获取从百度搜索进入网站的关键词的详细代码
2014/01/08 PHP
PHP中new static()与new self()的比较
2016/08/19 PHP
删除PHP数组中的重复元素的实现代码
2017/04/10 PHP
jquery validate.js表单验证的基本用法入门
2010/05/13 Javascript
用jquery统计子菜单的条数示例代码
2013/10/18 Javascript
Node.js 异步编程之 Callback介绍(一)
2015/03/30 Javascript
javascript伸缩型菜单实现代码
2015/11/16 Javascript
jQuery实现验证年龄简单思路
2016/02/24 Javascript
深入理解jQuery之防止冒泡事件
2016/05/24 Javascript
bootstrap laydate日期组件使用详解
2017/01/04 Javascript
JS优化与惰性载入函数实例分析
2017/04/06 Javascript
jQuery Validate 校验多个相同name的方法
2017/05/18 jQuery
JS判断时间段的实现代码
2017/06/14 Javascript
5分钟打造简易高效的webpack常用配置
2017/07/04 Javascript
ReactNative页面跳转Navigator实现的示例代码
2017/08/02 Javascript
基于Cookie常用操作以及属性介绍
2017/09/07 Javascript
谈谈vue中mixin的一点理解
2017/12/12 Javascript
React native ListView 增加顶部下拉刷新和底下点击刷新示例
2018/04/27 Javascript
JavaScript中call和apply方法的区别实例分析
2018/08/03 Javascript
vue+element-ui实现表格编辑的三种实现方式
2018/10/31 Javascript
微信小程序合法域名配置方法
2019/05/06 Javascript
简单了解Ajax表单序列化的实现方法
2019/06/14 Javascript
6种JavaScript继承方式及优缺点(小结)
2020/02/06 Javascript
[04:40]2016个国际邀请赛中国区预选赛场地——华西村观战指南
2016/06/25 DOTA
在Django的模板中使用认证数据的方法
2015/07/23 Python
Pytorch中膨胀卷积的用法详解
2020/01/07 Python
Python 通过正则表达式快速获取电影的下载地址
2020/08/17 Python
全球最大的跑步用品商店:Road Runner Sports
2016/09/11 全球购物
第二层交换机和路由器的区别?第三层交换机和路由器的区别?
2013/05/23 面试题
Exception类的常用方法
2012/06/16 面试题
工程质量月活动方案
2014/02/19 职场文书
2014年开学第一课活动方案
2014/03/06 职场文书
基督教婚礼主持词
2014/03/14 职场文书
2014年机关党委工作总结
2014/12/11 职场文书
毕业设计致谢语
2015/05/14 职场文书